pytme 0.2.9__cp311-cp311-macosx_15_0_arm64.whl → 0.3b0__cp311-cp311-macosx_15_0_arm64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- pytme-0.2.9.data/scripts/estimate_ram_usage.py → pytme-0.3b0.data/scripts/estimate_memory_usage.py +16 -33
- {pytme-0.2.9.data → pytme-0.3b0.data}/scripts/match_template.py +224 -223
- {pytme-0.2.9.data → pytme-0.3b0.data}/scripts/postprocess.py +283 -163
- {pytme-0.2.9.data → pytme-0.3b0.data}/scripts/preprocess.py +11 -8
- {pytme-0.2.9.data → pytme-0.3b0.data}/scripts/preprocessor_gui.py +10 -9
- {pytme-0.2.9.dist-info → pytme-0.3b0.dist-info}/METADATA +11 -9
- {pytme-0.2.9.dist-info → pytme-0.3b0.dist-info}/RECORD +61 -58
- {pytme-0.2.9.dist-info → pytme-0.3b0.dist-info}/entry_points.txt +1 -1
- scripts/{estimate_ram_usage.py → estimate_memory_usage.py} +16 -33
- scripts/extract_candidates.py +224 -0
- scripts/match_template.py +224 -223
- scripts/postprocess.py +283 -163
- scripts/preprocess.py +11 -8
- scripts/preprocessor_gui.py +10 -9
- scripts/refine_matches.py +626 -0
- tests/preprocessing/test_frequency_filters.py +9 -4
- tests/test_analyzer.py +143 -138
- tests/test_matching_cli.py +85 -29
- tests/test_matching_exhaustive.py +1 -2
- tests/test_matching_optimization.py +4 -9
- tests/test_orientations.py +0 -1
- tme/__version__.py +1 -1
- tme/analyzer/__init__.py +2 -0
- tme/analyzer/_utils.py +25 -17
- tme/analyzer/aggregation.py +385 -220
- tme/analyzer/base.py +138 -0
- tme/analyzer/peaks.py +150 -88
- tme/analyzer/proxy.py +122 -0
- tme/backends/__init__.py +4 -3
- tme/backends/_cupy_utils.py +25 -24
- tme/backends/_jax_utils.py +4 -3
- tme/backends/cupy_backend.py +4 -13
- tme/backends/jax_backend.py +6 -8
- tme/backends/matching_backend.py +4 -3
- tme/backends/mlx_backend.py +4 -3
- tme/backends/npfftw_backend.py +7 -5
- tme/backends/pytorch_backend.py +14 -4
- tme/cli.py +126 -0
- tme/density.py +4 -3
- tme/filters/__init__.py +1 -1
- tme/filters/_utils.py +4 -3
- tme/filters/bandpass.py +6 -4
- tme/filters/compose.py +5 -4
- tme/filters/ctf.py +426 -214
- tme/filters/reconstruction.py +58 -28
- tme/filters/wedge.py +139 -61
- tme/filters/whitening.py +36 -36
- tme/matching_data.py +4 -3
- tme/matching_exhaustive.py +17 -16
- tme/matching_optimization.py +5 -4
- tme/matching_scores.py +4 -3
- tme/matching_utils.py +6 -4
- tme/memory.py +4 -3
- tme/orientations.py +9 -6
- tme/parser.py +5 -4
- tme/preprocessor.py +4 -3
- tme/rotations.py +10 -7
- tme/structure.py +4 -3
- tests/data/Maps/.DS_Store +0 -0
- tests/data/Structures/.DS_Store +0 -0
- {pytme-0.2.9.dist-info → pytme-0.3b0.dist-info}/WHEEL +0 -0
- {pytme-0.2.9.dist-info → pytme-0.3b0.dist-info}/licenses/LICENSE +0 -0
- {pytme-0.2.9.dist-info → pytme-0.3b0.dist-info}/top_level.txt +0 -0
scripts/match_template.py
CHANGED
@@ -1,9 +1,9 @@
|
|
1
1
|
#!python3
|
2
|
-
"""
|
2
|
+
"""CLI for basic pyTME template matching functions.
|
3
3
|
|
4
|
-
|
4
|
+
Copyright (c) 2023 European Molecular Biology Laboratory
|
5
5
|
|
6
|
-
|
6
|
+
Author: Valentin Maurer <valentin.maurer@embl-hamburg.de>
|
7
7
|
"""
|
8
8
|
import os
|
9
9
|
import argparse
|
@@ -18,61 +18,31 @@ from tempfile import gettempdir
|
|
18
18
|
import numpy as np
|
19
19
|
|
20
20
|
from tme.backends import backend as be
|
21
|
-
from tme import Density, __version__
|
21
|
+
from tme import Density, __version__, Orientations
|
22
22
|
from tme.matching_utils import scramble_phases, write_pickle
|
23
23
|
from tme.matching_exhaustive import scan_subsets, MATCHING_EXHAUSTIVE_REGISTER
|
24
24
|
from tme.rotations import (
|
25
25
|
get_cone_rotations,
|
26
26
|
get_rotation_matrices,
|
27
|
+
euler_to_rotationmatrix,
|
27
28
|
)
|
28
29
|
from tme.matching_data import MatchingData
|
29
30
|
from tme.analyzer import (
|
30
31
|
MaxScoreOverRotations,
|
31
32
|
PeakCallerMaximumFilter,
|
33
|
+
MaxScoreOverRotationsConstrained,
|
32
34
|
)
|
33
35
|
from tme.filters import (
|
34
36
|
CTF,
|
35
37
|
Wedge,
|
36
38
|
Compose,
|
37
39
|
BandPassFilter,
|
40
|
+
CTFReconstructed,
|
38
41
|
WedgeReconstructed,
|
39
42
|
ReconstructFromTilt,
|
40
43
|
LinearWhiteningFilter,
|
41
44
|
)
|
42
|
-
|
43
|
-
|
44
|
-
def get_func_fullname(func) -> str:
|
45
|
-
"""Returns the full name of the given function, including its module."""
|
46
|
-
return f"<function '{func.__module__}.{func.__name__}'>"
|
47
|
-
|
48
|
-
|
49
|
-
def print_block(name: str, data: dict, label_width=20) -> None:
|
50
|
-
"""Prints a formatted block of information."""
|
51
|
-
print(f"\n> {name}")
|
52
|
-
for key, value in data.items():
|
53
|
-
if isinstance(value, np.ndarray):
|
54
|
-
value = value.shape
|
55
|
-
formatted_value = str(value)
|
56
|
-
print(f" - {key + ':':<{label_width}} {formatted_value}")
|
57
|
-
|
58
|
-
|
59
|
-
def print_entry() -> None:
|
60
|
-
width = 80
|
61
|
-
text = f" pytme v{__version__} "
|
62
|
-
padding_total = width - len(text) - 2
|
63
|
-
padding_left = padding_total // 2
|
64
|
-
padding_right = padding_total - padding_left
|
65
|
-
|
66
|
-
print("*" * width)
|
67
|
-
print(f"*{ ' ' * padding_left }{text}{ ' ' * padding_right }*")
|
68
|
-
print("*" * width)
|
69
|
-
|
70
|
-
|
71
|
-
def check_positive(value):
|
72
|
-
ivalue = float(value)
|
73
|
-
if ivalue <= 0:
|
74
|
-
raise argparse.ArgumentTypeError("%s is an invalid positive float." % value)
|
75
|
-
return ivalue
|
45
|
+
from tme.cli import get_func_fullname, print_block, print_entry, check_positive
|
76
46
|
|
77
47
|
|
78
48
|
def load_and_validate_mask(mask_target: "Density", mask_path: str, **kwargs):
|
@@ -118,6 +88,14 @@ def load_and_validate_mask(mask_target: "Density", mask_path: str, **kwargs):
|
|
118
88
|
|
119
89
|
|
120
90
|
def parse_rotation_logic(args, ndim):
|
91
|
+
if args.particle_diameter is not None:
|
92
|
+
resolution = Density.from_file(args.target, use_memmap=True)
|
93
|
+
resolution = 360 * np.maximum(
|
94
|
+
np.max(2 * resolution.sampling_rate),
|
95
|
+
args.lowpass if args.lowpass is not None else 0,
|
96
|
+
)
|
97
|
+
args.angular_sampling = resolution / (3.14159265358979 * args.particle_diameter)
|
98
|
+
|
121
99
|
if args.angular_sampling is not None:
|
122
100
|
rotations = get_rotation_matrices(
|
123
101
|
angular_sampling=args.angular_sampling,
|
@@ -178,123 +156,72 @@ def compute_schedule(
|
|
178
156
|
|
179
157
|
|
180
158
|
def setup_filter(args, template: Density, target: Density) -> Tuple[Compose, Compose]:
|
181
|
-
needs_reconstruction = False
|
182
159
|
template_filter, target_filter = [], []
|
160
|
+
|
161
|
+
wedge = None
|
183
162
|
if args.tilt_angles is not None:
|
184
|
-
needs_reconstruction = args.tilt_weighting is not None
|
185
163
|
try:
|
186
164
|
wedge = Wedge.from_file(args.tilt_angles)
|
187
165
|
wedge.weight_type = args.tilt_weighting
|
188
|
-
if args.tilt_weighting in ("angle", None)
|
166
|
+
if args.tilt_weighting in ("angle", None):
|
189
167
|
wedge = WedgeReconstructed(
|
190
168
|
angles=wedge.angles,
|
191
169
|
weight_wedge=args.tilt_weighting == "angle",
|
192
|
-
opening_axis=args.wedge_axes[0],
|
193
|
-
tilt_axis=args.wedge_axes[1],
|
194
170
|
)
|
195
|
-
except FileNotFoundError:
|
196
|
-
tilt_step, create_continuous_wedge = None, True
|
171
|
+
except (FileNotFoundError, AttributeError):
|
197
172
|
tilt_start, tilt_stop = args.tilt_angles.split(",")
|
198
|
-
|
199
|
-
|
200
|
-
|
201
|
-
|
202
|
-
|
203
|
-
|
204
|
-
tilt_step = float(tilt_step)
|
205
|
-
tilt_angles = np.arange(
|
206
|
-
-tilt_start, tilt_stop + tilt_step, tilt_step
|
207
|
-
).tolist()
|
208
|
-
|
209
|
-
if args.tilt_weighting is not None and tilt_step is None:
|
210
|
-
raise ValueError(
|
211
|
-
"Tilt weighting is not supported for continuous wedges."
|
212
|
-
)
|
213
|
-
if args.tilt_weighting not in ("angle", None):
|
214
|
-
raise ValueError(
|
215
|
-
"Tilt weighting schemes other than 'angle' or 'None' require "
|
216
|
-
"a specification of electron doses via --tilt_angles."
|
217
|
-
)
|
218
|
-
|
219
|
-
wedge = Wedge(
|
220
|
-
angles=tilt_angles,
|
221
|
-
opening_axis=args.wedge_axes[0],
|
222
|
-
tilt_axis=args.wedge_axes[1],
|
223
|
-
shape=None,
|
224
|
-
weight_type=None,
|
225
|
-
weights=np.ones_like(tilt_angles),
|
173
|
+
tilt_start, tilt_stop = abs(float(tilt_start)), abs(float(tilt_stop))
|
174
|
+
wedge = WedgeReconstructed(
|
175
|
+
angles=(tilt_start, tilt_stop),
|
176
|
+
create_continuous_wedge=True,
|
177
|
+
weight_wedge=False,
|
178
|
+
reconstruction_filter=args.reconstruction_filter,
|
226
179
|
)
|
227
|
-
if args.tilt_weighting in ("angle", None) and args.ctf_file is None:
|
228
|
-
wedge = WedgeReconstructed(
|
229
|
-
angles=tilt_angles,
|
230
|
-
weight_wedge=args.tilt_weighting == "angle",
|
231
|
-
create_continuous_wedge=create_continuous_wedge,
|
232
|
-
reconstruction_filter=args.reconstruction_filter,
|
233
|
-
opening_axis=args.wedge_axes[0],
|
234
|
-
tilt_axis=args.wedge_axes[1],
|
235
|
-
)
|
236
|
-
wedge_target = WedgeReconstructed(
|
237
|
-
angles=(np.abs(np.min(tilt_angles)), np.abs(np.max(tilt_angles))),
|
238
|
-
weight_wedge=False,
|
239
|
-
create_continuous_wedge=True,
|
240
|
-
opening_axis=args.wedge_axes[0],
|
241
|
-
tilt_axis=args.wedge_axes[1],
|
242
|
-
)
|
243
|
-
target_filter.append(wedge_target)
|
244
180
|
|
245
|
-
|
181
|
+
wedge_target = WedgeReconstructed(
|
182
|
+
angles=wedge.angles,
|
183
|
+
weight_wedge=False,
|
184
|
+
create_continuous_wedge=True,
|
185
|
+
opening_axis=args.wedge_axes[0],
|
186
|
+
tilt_axis=args.wedge_axes[1],
|
187
|
+
)
|
188
|
+
wedge.opening_axis = args.wedge_axes[0]
|
189
|
+
wedge.tilt_axis = args.wedge_axes[1]
|
190
|
+
|
191
|
+
target_filter.append(wedge_target)
|
246
192
|
template_filter.append(wedge)
|
247
|
-
if not isinstance(wedge, WedgeReconstructed):
|
248
|
-
reconstruction_filter = ReconstructFromTilt(
|
249
|
-
reconstruction_filter=args.reconstruction_filter,
|
250
|
-
interpolation_order=args.reconstruction_interpolation_order,
|
251
|
-
)
|
252
|
-
template_filter.append(reconstruction_filter)
|
253
193
|
|
194
|
+
args.ctf_file is not None
|
254
195
|
if args.ctf_file is not None or args.defocus is not None:
|
255
|
-
|
256
|
-
if args.ctf_file is not None:
|
196
|
+
try:
|
257
197
|
ctf = CTF.from_file(args.ctf_file)
|
198
|
+
if (len(ctf.angles) == 0) and wedge is None:
|
199
|
+
raise ValueError(
|
200
|
+
"You requested to specify the CTF per tilt, but did not specify "
|
201
|
+
"tilt angles via --tilt_angles or --ctf_file (Warp/M XML format). "
|
202
|
+
)
|
203
|
+
if len(ctf.angles) == 0:
|
204
|
+
ctf.angles = wedge.angles
|
205
|
+
|
258
206
|
n_tilts_ctfs, n_tils_angles = len(ctf.defocus_x), len(wedge.angles)
|
259
|
-
if n_tilts_ctfs != n_tils_angles:
|
207
|
+
if (n_tilts_ctfs != n_tils_angles) and isinstance(wedge, Wedge):
|
260
208
|
raise ValueError(
|
261
|
-
f"CTF file contains {n_tilts_ctfs}
|
209
|
+
f"CTF file contains {n_tilts_ctfs} tilt, but match_template "
|
262
210
|
f"recieved {n_tils_angles} tilt angles. Expected one angle "
|
263
|
-
"per
|
211
|
+
"per tilt."
|
264
212
|
)
|
265
|
-
|
266
|
-
|
267
|
-
ctf.
|
268
|
-
|
269
|
-
|
270
|
-
ctf = CTF(
|
271
|
-
defocus_x=[args.defocus],
|
272
|
-
phase_shift=[args.phase_shift],
|
273
|
-
defocus_y=None,
|
274
|
-
angles=[0],
|
275
|
-
shape=None,
|
276
|
-
return_real_fourier=True,
|
277
|
-
)
|
213
|
+
|
214
|
+
except (FileNotFoundError, AttributeError):
|
215
|
+
ctf = CTFReconstructed(defocus_x=args.defocus, phase_shift=args.phase_shift)
|
216
|
+
|
217
|
+
ctf.opening_axis, ctf.tilt_axis = args.wedge_axes
|
278
218
|
ctf.sampling_rate = template.sampling_rate
|
279
219
|
ctf.flip_phase = args.no_flip_phase
|
280
220
|
ctf.amplitude_contrast = args.amplitude_contrast
|
281
221
|
ctf.spherical_aberration = args.spherical_aberration
|
282
222
|
ctf.acceleration_voltage = args.acceleration_voltage * 1e3
|
283
223
|
ctf.correct_defocus_gradient = args.correct_defocus_gradient
|
284
|
-
|
285
|
-
if not needs_reconstruction:
|
286
|
-
template_filter.append(ctf)
|
287
|
-
elif isinstance(template_filter[-1], ReconstructFromTilt):
|
288
|
-
template_filter.insert(-1, ctf)
|
289
|
-
else:
|
290
|
-
template_filter.insert(0, ctf)
|
291
|
-
template_filter.insert(
|
292
|
-
1,
|
293
|
-
ReconstructFromTilt(
|
294
|
-
reconstruction_filter=args.reconstruction_filter,
|
295
|
-
interpolation_order=args.reconstruction_interpolation_order,
|
296
|
-
),
|
297
|
-
)
|
224
|
+
template_filter.append(ctf)
|
298
225
|
|
299
226
|
if args.lowpass or args.highpass is not None:
|
300
227
|
lowpass, highpass = args.lowpass, args.highpass
|
@@ -329,11 +256,31 @@ def setup_filter(args, template: Density, target: Density) -> Tuple[Compose, Com
|
|
329
256
|
template_filter.append(whitening_filter)
|
330
257
|
target_filter.append(whitening_filter)
|
331
258
|
|
332
|
-
|
259
|
+
rec_filt = (Wedge, CTF)
|
260
|
+
needs_reconstruction = sum(type(x) in rec_filt for x in template_filter)
|
261
|
+
if needs_reconstruction > 0 and args.reconstruction_filter is None:
|
333
262
|
warnings.warn(
|
334
|
-
"Consider using a --reconstruction_filter such as '
|
263
|
+
"Consider using a --reconstruction_filter such as 'ram-lak' or 'ramp' "
|
264
|
+
"to avoid artifacts from reconstruction using weighted backprojection."
|
335
265
|
)
|
336
266
|
|
267
|
+
template_filter = sorted(
|
268
|
+
template_filter, key=lambda x: type(x) in rec_filt, reverse=True
|
269
|
+
)
|
270
|
+
if needs_reconstruction > 0:
|
271
|
+
relevant_filters = [x for x in template_filter if type(x) in rec_filt]
|
272
|
+
if len(relevant_filters) == 0:
|
273
|
+
raise ValueError("Filters require ")
|
274
|
+
|
275
|
+
reconstruction_filter = ReconstructFromTilt(
|
276
|
+
reconstruction_filter=args.reconstruction_filter,
|
277
|
+
interpolation_order=args.reconstruction_interpolation_order,
|
278
|
+
angles=relevant_filters[0].angles,
|
279
|
+
opening_axis=args.wedge_axes[0],
|
280
|
+
tilt_axis=args.wedge_axes[1],
|
281
|
+
)
|
282
|
+
template_filter.insert(needs_reconstruction, reconstruction_filter)
|
283
|
+
|
337
284
|
template_filter = Compose(template_filter) if len(template_filter) else None
|
338
285
|
target_filter = Compose(target_filter) if len(target_filter) else None
|
339
286
|
if args.no_filter_target:
|
@@ -342,6 +289,10 @@ def setup_filter(args, template: Density, target: Density) -> Tuple[Compose, Com
|
|
342
289
|
return template_filter, target_filter
|
343
290
|
|
344
291
|
|
292
|
+
def _format_sampling(arr, decimals: int = 2):
|
293
|
+
return tuple(round(float(x), decimals) for x in arr)
|
294
|
+
|
295
|
+
|
345
296
|
def parse_args():
|
346
297
|
parser = argparse.ArgumentParser(
|
347
298
|
description="Perform template matching.",
|
@@ -406,6 +357,40 @@ def parse_args():
|
|
406
357
|
help="Phase scramble the template to generate a noise score background.",
|
407
358
|
)
|
408
359
|
|
360
|
+
sampling_group = parser.add_argument_group("Sampling")
|
361
|
+
sampling_group.add_argument(
|
362
|
+
"--orientations",
|
363
|
+
dest="orientations",
|
364
|
+
default=None,
|
365
|
+
required=False,
|
366
|
+
help="Path to a file readable via Orientations.from_file containing "
|
367
|
+
"translations and rotations of candidate peaks to refine.",
|
368
|
+
)
|
369
|
+
sampling_group.add_argument(
|
370
|
+
"--orientations_scaling",
|
371
|
+
required=False,
|
372
|
+
type=float,
|
373
|
+
default=1.0,
|
374
|
+
help="Scaling factor to map candidate translations onto the target. "
|
375
|
+
"Assuming coordinates are in Å and target sampling rate are 3Å/voxel, "
|
376
|
+
"the corresponding orientations_scaling would be 3.",
|
377
|
+
)
|
378
|
+
sampling_group.add_argument(
|
379
|
+
"--orientations_cone",
|
380
|
+
required=False,
|
381
|
+
type=float,
|
382
|
+
default=20.0,
|
383
|
+
help="Accept orientations within specified cone angle of each orientation.",
|
384
|
+
)
|
385
|
+
sampling_group.add_argument(
|
386
|
+
"--orientations_uncertainty",
|
387
|
+
required=False,
|
388
|
+
type=str,
|
389
|
+
default="10",
|
390
|
+
help="Accept translations within the specified radius of each orientation. "
|
391
|
+
"Can be a single value or comma-separated string for per-axis uncertainty.",
|
392
|
+
)
|
393
|
+
|
409
394
|
scoring_group = parser.add_argument_group("Scoring")
|
410
395
|
scoring_group.add_argument(
|
411
396
|
"-s",
|
@@ -435,6 +420,13 @@ def parse_args():
|
|
435
420
|
help="Half-angle of the cone to be sampled in degrees. Allows to sample a "
|
436
421
|
"narrow interval around a known orientation, e.g. for surface oversampling.",
|
437
422
|
)
|
423
|
+
angular_exclusive.add_argument(
|
424
|
+
"--particle_diameter",
|
425
|
+
dest="particle_diameter",
|
426
|
+
type=check_positive,
|
427
|
+
default=None,
|
428
|
+
help="Particle diameter in units of sampling rate.",
|
429
|
+
)
|
438
430
|
angular_group.add_argument(
|
439
431
|
"--cone_axis",
|
440
432
|
dest="cone_axis",
|
@@ -517,6 +509,7 @@ def parse_args():
|
|
517
509
|
computation_group.add_argument(
|
518
510
|
"-r",
|
519
511
|
"--ram",
|
512
|
+
"--memory",
|
520
513
|
dest="memory",
|
521
514
|
required=False,
|
522
515
|
type=int,
|
@@ -529,7 +522,7 @@ def parse_args():
|
|
529
522
|
required=False,
|
530
523
|
type=float,
|
531
524
|
default=0.85,
|
532
|
-
help="Fraction of available memory to be used. Ignored if --
|
525
|
+
help="Fraction of available memory to be used. Ignored if --memory is set.",
|
533
526
|
)
|
534
527
|
computation_group.add_argument(
|
535
528
|
"--temp_directory",
|
@@ -540,7 +533,7 @@ def parse_args():
|
|
540
533
|
computation_group.add_argument(
|
541
534
|
"--backend",
|
542
535
|
dest="backend",
|
543
|
-
default=
|
536
|
+
default=be._backend_name,
|
544
537
|
choices=be.available_backends(),
|
545
538
|
help="[Expert] Overwrite default computation backend.",
|
546
539
|
)
|
@@ -575,7 +568,8 @@ def parse_args():
|
|
575
568
|
required=False,
|
576
569
|
default="sampling_rate",
|
577
570
|
choices=["sampling_rate", "voxel", "frequency"],
|
578
|
-
help="How values passed to --lowpass and --highpass should be interpreted. "
|
571
|
+
help="How values passed to --lowpass and --highpass should be interpreted. "
|
572
|
+
"Defaults to unit of sampling_rate, e.g., 40 Angstrom.",
|
579
573
|
)
|
580
574
|
filter_group.add_argument(
|
581
575
|
"--whiten_spectrum",
|
@@ -589,9 +583,9 @@ def parse_args():
|
|
589
583
|
dest="wedge_axes",
|
590
584
|
type=str,
|
591
585
|
required=False,
|
592
|
-
default=
|
593
|
-
help="Indices of wedge opening and tilt axis, e.g
|
594
|
-
"
|
586
|
+
default="2,0",
|
587
|
+
help="Indices of projection (wedge opening) and tilt axis, e.g., '2,0' "
|
588
|
+
"for the typical projection over z and tilting over the x-axis.",
|
595
589
|
)
|
596
590
|
filter_group.add_argument(
|
597
591
|
"--tilt_angles",
|
@@ -599,10 +593,12 @@ def parse_args():
|
|
599
593
|
type=str,
|
600
594
|
required=False,
|
601
595
|
default=None,
|
602
|
-
help="Path to a
|
603
|
-
"
|
604
|
-
"
|
605
|
-
"
|
596
|
+
help="Path to a file specifying tilt angles. This can be a Warp/M XML file, "
|
597
|
+
"a tomostar STAR file, a tab-separated file with column name 'angles', or a "
|
598
|
+
"single column file without header. Exposure will be taken from the input file "
|
599
|
+
", if you are using a tab-separated file, the column names 'angles' and "
|
600
|
+
"'weights' need to be present. It is also possible to specify a continuous "
|
601
|
+
"wedge mask using e.g., -50,45.",
|
606
602
|
)
|
607
603
|
filter_group.add_argument(
|
608
604
|
"--tilt_weighting",
|
@@ -649,8 +645,9 @@ def parse_args():
|
|
649
645
|
type=str,
|
650
646
|
required=False,
|
651
647
|
default=None,
|
652
|
-
help="Path to a file with CTF parameters
|
653
|
-
"
|
648
|
+
help="Path to a file with CTF parameters. This can be a Warp/M XML file "
|
649
|
+
"a GCTF/Relion STAR file, or the output of CTFFIND4. If the file does not "
|
650
|
+
"specify tilt angles, the angles specified with --tilt_angles are used.",
|
654
651
|
)
|
655
652
|
ctf_group.add_argument(
|
656
653
|
"--defocus",
|
@@ -658,8 +655,8 @@ def parse_args():
|
|
658
655
|
type=float,
|
659
656
|
required=False,
|
660
657
|
default=None,
|
661
|
-
help="Defocus in units of sampling rate (typically Ångstrom). "
|
662
|
-
"Superseded by --ctf_file.",
|
658
|
+
help="Defocus in units of sampling rate (typically Ångstrom), e.g., 30000 "
|
659
|
+
"for a defocus of 3 micrometer. Superseded by --ctf_file.",
|
663
660
|
)
|
664
661
|
ctf_group.add_argument(
|
665
662
|
"--phase_shift",
|
@@ -745,7 +742,8 @@ def parse_args():
|
|
745
742
|
dest="use_mixed_precision",
|
746
743
|
action="store_true",
|
747
744
|
default=False,
|
748
|
-
help="Use float16 for real values operations where possible."
|
745
|
+
help="Use float16 for real values operations where possible. Not supported "
|
746
|
+
"for jax backend.",
|
749
747
|
)
|
750
748
|
performance_group.add_argument(
|
751
749
|
"--use_memmap",
|
@@ -773,8 +771,8 @@ def parse_args():
|
|
773
771
|
help="Perform peak calling instead of score aggregation.",
|
774
772
|
)
|
775
773
|
analyzer_group.add_argument(
|
776
|
-
"--
|
777
|
-
dest="
|
774
|
+
"--num_peaks",
|
775
|
+
dest="num_peaks",
|
778
776
|
action="store_true",
|
779
777
|
default=1000,
|
780
778
|
help="Number of peaks to call, 1000 by default.",
|
@@ -794,21 +792,14 @@ def parse_args():
|
|
794
792
|
f"score has to be one of {', '.join(MATCHING_EXHAUSTIVE_REGISTER.keys())}"
|
795
793
|
)
|
796
794
|
|
797
|
-
gpu_devices = os.environ.get("CUDA_VISIBLE_DEVICES", None)
|
798
795
|
if args.gpu_indices is not None:
|
799
796
|
os.environ["CUDA_VISIBLE_DEVICES"] = args.gpu_indices
|
800
797
|
|
801
798
|
if args.use_gpu:
|
802
|
-
|
803
|
-
|
804
|
-
|
805
|
-
|
806
|
-
"Assuming device 0.",
|
807
|
-
)
|
808
|
-
os.environ["CUDA_VISIBLE_DEVICES"] = "0"
|
809
|
-
args.gpu_indices = [
|
810
|
-
int(x) for x in os.environ["CUDA_VISIBLE_DEVICES"].split(",")
|
811
|
-
]
|
799
|
+
warnings.warn(
|
800
|
+
"The use_gpu flag is no longer required and automatically "
|
801
|
+
"determined based on the selected backend."
|
802
|
+
)
|
812
803
|
|
813
804
|
if args.tilt_angles is not None:
|
814
805
|
if args.wedge_axes is None:
|
@@ -825,6 +816,13 @@ def parse_args():
|
|
825
816
|
if args.wedge_axes is not None:
|
826
817
|
args.wedge_axes = tuple(int(i) for i in args.wedge_axes.split(","))
|
827
818
|
|
819
|
+
if args.orientations is not None:
|
820
|
+
orientations = Orientations.from_file(args.orientations)
|
821
|
+
orientations.translations = np.divide(
|
822
|
+
orientations.translations, args.orientations_scaling
|
823
|
+
)
|
824
|
+
args.orientations = orientations
|
825
|
+
|
828
826
|
return args
|
829
827
|
|
830
828
|
|
@@ -864,7 +862,7 @@ def main():
|
|
864
862
|
name="Target",
|
865
863
|
data={
|
866
864
|
"Initial Shape": initial_shape,
|
867
|
-
"Sampling Rate":
|
865
|
+
"Sampling Rate": _format_sampling(target.sampling_rate),
|
868
866
|
"Final Shape": target.shape,
|
869
867
|
},
|
870
868
|
)
|
@@ -874,7 +872,7 @@ def main():
|
|
874
872
|
name="Target Mask",
|
875
873
|
data={
|
876
874
|
"Initial Shape": initial_shape,
|
877
|
-
"Sampling Rate":
|
875
|
+
"Sampling Rate": _format_sampling(target_mask.sampling_rate),
|
878
876
|
"Final Shape": target_mask.shape,
|
879
877
|
},
|
880
878
|
)
|
@@ -887,7 +885,7 @@ def main():
|
|
887
885
|
name="Template",
|
888
886
|
data={
|
889
887
|
"Initial Shape": initial_shape,
|
890
|
-
"Sampling Rate":
|
888
|
+
"Sampling Rate": _format_sampling(template.sampling_rate),
|
891
889
|
"Final Shape": template.shape,
|
892
890
|
},
|
893
891
|
)
|
@@ -919,7 +917,7 @@ def main():
|
|
919
917
|
name="Template Mask",
|
920
918
|
data={
|
921
919
|
"Inital Shape": initial_shape,
|
922
|
-
"Sampling Rate":
|
920
|
+
"Sampling Rate": _format_sampling(template_mask.sampling_rate),
|
923
921
|
"Final Shape": template_mask.shape,
|
924
922
|
},
|
925
923
|
)
|
@@ -930,65 +928,71 @@ def main():
|
|
930
928
|
template.data, noise_proportion=1.0, normalize_power=False
|
931
929
|
)
|
932
930
|
|
933
|
-
|
934
|
-
available_backends = be.available_backends()
|
935
|
-
if args.backend is not None:
|
936
|
-
req_backend = args.backend
|
937
|
-
if req_backend not in available_backends:
|
938
|
-
raise ValueError("Requested backend is not available.")
|
939
|
-
available_backends = [req_backend]
|
940
|
-
|
941
|
-
be_selection = ("numpyfftw", "pytorch", "jax", "mlx")
|
942
|
-
if args.use_gpu:
|
943
|
-
args.cores = len(args.gpu_indices)
|
944
|
-
be_selection = ("pytorch", "cupy", "jax")
|
945
|
-
if args.use_mixed_precision:
|
946
|
-
be_selection = tuple(x for x in be_selection if x in ("cupy", "numpyfftw"))
|
947
|
-
|
948
|
-
available_backends = [x for x in available_backends if x in be_selection]
|
931
|
+
callback_class = MaxScoreOverRotations
|
949
932
|
if args.peak_calling:
|
950
|
-
|
951
|
-
available_backends.remove("jax")
|
952
|
-
if args.use_gpu and "pytorch" in available_backends:
|
953
|
-
available_backends = ("pytorch",)
|
933
|
+
callback_class = PeakCallerMaximumFilter
|
954
934
|
|
955
|
-
|
956
|
-
|
957
|
-
# args.interpolation_order = 1
|
958
|
-
# available_backends = ["jax"]
|
935
|
+
if args.orientations is not None:
|
936
|
+
callback_class = MaxScoreOverRotationsConstrained
|
959
937
|
|
960
|
-
|
961
|
-
|
962
|
-
|
963
|
-
|
964
|
-
|
965
|
-
|
966
|
-
|
967
|
-
|
968
|
-
be.change_backend(pref, device="cuda" if args.use_gpu else "cpu")
|
969
|
-
|
970
|
-
if args.use_mixed_precision:
|
971
|
-
be.change_backend(
|
972
|
-
backend_name=pref,
|
973
|
-
default_dtype=be._array_backend.float16,
|
974
|
-
complex_dtype=be._array_backend.complex64,
|
975
|
-
default_dtype_int=be._array_backend.int16,
|
976
|
-
)
|
977
|
-
break
|
938
|
+
# Determine suitable backend for the selected operation
|
939
|
+
available_backends = be.available_backends()
|
940
|
+
if args.backend not in available_backends:
|
941
|
+
raise ValueError("Requested backend is not available.")
|
942
|
+
if args.backend == "jax" and callback_class != MaxScoreOverRotations:
|
943
|
+
raise ValueError(
|
944
|
+
"Jax backend only supports the MaxScoreOverRotations analyzer."
|
945
|
+
)
|
978
946
|
|
979
|
-
if
|
947
|
+
if args.interpolation_order == 3 and args.backend in ("jax", "pytorch"):
|
980
948
|
warnings.warn(
|
981
|
-
"
|
949
|
+
"Jax and pytorch do not support interpolation order 3, setting it to 1."
|
982
950
|
)
|
983
951
|
args.interpolation_order = 1
|
984
952
|
|
953
|
+
if args.backend in ("pytorch", "cupy", "jax"):
|
954
|
+
gpu_devices = os.environ.get("CUDA_VISIBLE_DEVICES", None)
|
955
|
+
if gpu_devices is None:
|
956
|
+
warnings.warn(
|
957
|
+
"No GPU indices provided and CUDA_VISIBLE_DEVICES is not set. "
|
958
|
+
"Assuming device 0.",
|
959
|
+
)
|
960
|
+
os.environ["CUDA_VISIBLE_DEVICES"] = "0"
|
961
|
+
else:
|
962
|
+
args.cores = len(os.environ["CUDA_VISIBLE_DEVICES"].split(","))
|
963
|
+
args.gpu_indices = [
|
964
|
+
int(x) for x in os.environ["CUDA_VISIBLE_DEVICES"].split(",")
|
965
|
+
]
|
966
|
+
|
967
|
+
# Finally set the desired backend
|
968
|
+
device = "cuda"
|
969
|
+
be.change_backend(args.backend)
|
970
|
+
if args.backend == "pytorch":
|
971
|
+
try:
|
972
|
+
be.change_backend("pytorch", device=device)
|
973
|
+
# Trigger exception if not compiled with device
|
974
|
+
be.get_available_memory()
|
975
|
+
except Exception as e:
|
976
|
+
print(e)
|
977
|
+
device = "cpu"
|
978
|
+
be.change_backend("pytorch", device=device)
|
979
|
+
if args.use_mixed_precision:
|
980
|
+
be.change_backend(
|
981
|
+
backend_name=args.backend,
|
982
|
+
default_dtype=be._array_backend.float16,
|
983
|
+
complex_dtype=be._array_backend.complex64,
|
984
|
+
default_dtype_int=be._array_backend.int16,
|
985
|
+
device=device,
|
986
|
+
)
|
987
|
+
|
985
988
|
available_memory = be.get_available_memory() * be.device_count()
|
986
989
|
if args.memory is None:
|
987
990
|
args.memory = int(args.memory_scaling * available_memory)
|
988
991
|
|
989
|
-
|
990
|
-
|
991
|
-
|
992
|
+
if args.orientations_uncertainty is not None:
|
993
|
+
args.orientations_uncertainty = tuple(
|
994
|
+
int(x) for x in args.orientations_uncertainty.split(",")
|
995
|
+
)
|
992
996
|
|
993
997
|
matching_data = MatchingData(
|
994
998
|
target=target,
|
@@ -1079,10 +1083,19 @@ def main():
|
|
1079
1083
|
|
1080
1084
|
analyzer_args = {
|
1081
1085
|
"score_threshold": args.score_threshold,
|
1082
|
-
"
|
1086
|
+
"num_peaks": args.num_peaks,
|
1083
1087
|
"min_distance": max(template.shape) // 3,
|
1084
1088
|
"use_memmap": args.use_memmap,
|
1085
1089
|
}
|
1090
|
+
if args.orientations is not None:
|
1091
|
+
analyzer_args["reference"] = (0, 0, 1)
|
1092
|
+
analyzer_args["cone_angle"] = args.orientations_cone
|
1093
|
+
analyzer_args["acceptance_radius"] = args.orientations_uncertainty
|
1094
|
+
analyzer_args["positions"] = args.orientations.translations
|
1095
|
+
analyzer_args["rotations"] = euler_to_rotationmatrix(
|
1096
|
+
args.orientations.rotations
|
1097
|
+
)
|
1098
|
+
|
1086
1099
|
print_block(
|
1087
1100
|
name="Analyzer",
|
1088
1101
|
data={"Analyzer": callback_class, **analyzer_args},
|
@@ -1111,18 +1124,6 @@ def main():
|
|
1111
1124
|
)
|
1112
1125
|
|
1113
1126
|
candidates = list(candidates) if candidates is not None else []
|
1114
|
-
if issubclass(callback_class, MaxScoreOverRotations):
|
1115
|
-
if target_mask is not None and args.score != "MCC":
|
1116
|
-
candidates[0] *= target_mask.data
|
1117
|
-
with warnings.catch_warnings():
|
1118
|
-
warnings.simplefilter("ignore", category=UserWarning)
|
1119
|
-
nbytes = be.datatype_bytes(be._float_dtype)
|
1120
|
-
dtype = np.float32 if nbytes == 4 else np.float16
|
1121
|
-
rot_dim = matching_data.rotations.shape[1]
|
1122
|
-
candidates[3] = {
|
1123
|
-
x: np.frombuffer(i, dtype=dtype).reshape(rot_dim, rot_dim)
|
1124
|
-
for i, x in candidates[3].items()
|
1125
|
-
}
|
1126
1127
|
candidates.append((target.origin, template.origin, template.sampling_rate, args))
|
1127
1128
|
write_pickle(data=candidates, filename=args.output)
|
1128
1129
|
|