pytme 0.2.0b0__cp311-cp311-macosx_14_0_arm64.whl → 0.2.1__cp311-cp311-macosx_14_0_arm64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {pytme-0.2.0b0.data → pytme-0.2.1.data}/scripts/match_template.py +473 -140
- {pytme-0.2.0b0.data → pytme-0.2.1.data}/scripts/postprocess.py +107 -49
- {pytme-0.2.0b0.data → pytme-0.2.1.data}/scripts/preprocessor_gui.py +4 -1
- {pytme-0.2.0b0.dist-info → pytme-0.2.1.dist-info}/METADATA +2 -2
- pytme-0.2.1.dist-info/RECORD +73 -0
- scripts/extract_candidates.py +117 -85
- scripts/match_template.py +473 -140
- scripts/match_template_filters.py +458 -169
- scripts/postprocess.py +107 -49
- scripts/preprocessor_gui.py +4 -1
- scripts/refine_matches.py +364 -160
- tme/__version__.py +1 -1
- tme/analyzer.py +278 -148
- tme/backends/__init__.py +1 -0
- tme/backends/cupy_backend.py +20 -13
- tme/backends/jax_backend.py +218 -0
- tme/backends/matching_backend.py +25 -10
- tme/backends/mlx_backend.py +13 -9
- tme/backends/npfftw_backend.py +22 -12
- tme/backends/pytorch_backend.py +20 -9
- tme/density.py +85 -64
- tme/extensions.cpython-311-darwin.so +0 -0
- tme/matching_data.py +86 -60
- tme/matching_exhaustive.py +245 -166
- tme/matching_optimization.py +137 -69
- tme/matching_utils.py +1 -1
- tme/orientations.py +175 -55
- tme/preprocessing/__init__.py +2 -0
- tme/preprocessing/_utils.py +188 -0
- tme/preprocessing/composable_filter.py +31 -0
- tme/preprocessing/compose.py +51 -0
- tme/preprocessing/frequency_filters.py +378 -0
- tme/preprocessing/tilt_series.py +1017 -0
- tme/preprocessor.py +17 -7
- tme/structure.py +4 -1
- pytme-0.2.0b0.dist-info/RECORD +0 -66
- {pytme-0.2.0b0.data → pytme-0.2.1.data}/scripts/estimate_ram_usage.py +0 -0
- {pytme-0.2.0b0.data → pytme-0.2.1.data}/scripts/preprocess.py +0 -0
- {pytme-0.2.0b0.dist-info → pytme-0.2.1.dist-info}/LICENSE +0 -0
- {pytme-0.2.0b0.dist-info → pytme-0.2.1.dist-info}/WHEEL +0 -0
- {pytme-0.2.0b0.dist-info → pytme-0.2.1.dist-info}/entry_points.txt +0 -0
- {pytme-0.2.0b0.dist-info → pytme-0.2.1.dist-info}/top_level.txt +0 -0
@@ -13,13 +13,14 @@ from sys import exit
|
|
13
13
|
from time import time
|
14
14
|
from typing import Tuple
|
15
15
|
from copy import deepcopy
|
16
|
-
from os.path import abspath
|
16
|
+
from os.path import abspath, exists
|
17
17
|
|
18
18
|
import numpy as np
|
19
19
|
|
20
20
|
from tme import Density, __version__
|
21
21
|
from tme.matching_utils import (
|
22
22
|
get_rotation_matrices,
|
23
|
+
get_rotations_around_vector,
|
23
24
|
compute_parallelization_schedule,
|
24
25
|
euler_from_rotationmatrix,
|
25
26
|
scramble_phases,
|
@@ -32,8 +33,8 @@ from tme.analyzer import (
|
|
32
33
|
MaxScoreOverRotations,
|
33
34
|
PeakCallerMaximumFilter,
|
34
35
|
)
|
35
|
-
from tme.preprocessing import Compose
|
36
36
|
from tme.backends import backend
|
37
|
+
from tme.preprocessing import Compose
|
37
38
|
|
38
39
|
|
39
40
|
def get_func_fullname(func) -> str:
|
@@ -152,6 +153,187 @@ def crop_data(data: Density, cutoff: float, data_mask: Density = None) -> bool:
|
|
152
153
|
return True
|
153
154
|
|
154
155
|
|
156
|
+
def parse_rotation_logic(args, ndim):
|
157
|
+
if args.angular_sampling is not None:
|
158
|
+
rotations = get_rotation_matrices(
|
159
|
+
angular_sampling=args.angular_sampling,
|
160
|
+
dim=ndim,
|
161
|
+
use_optimized_set=not args.no_use_optimized_set,
|
162
|
+
)
|
163
|
+
if args.angular_sampling >= 180:
|
164
|
+
rotations = np.eye(ndim).reshape(1, ndim, ndim)
|
165
|
+
return rotations
|
166
|
+
|
167
|
+
if args.axis_sampling is None:
|
168
|
+
args.axis_sampling = args.cone_sampling
|
169
|
+
|
170
|
+
rotations = get_rotations_around_vector(
|
171
|
+
cone_angle=args.cone_angle,
|
172
|
+
cone_sampling=args.cone_sampling,
|
173
|
+
axis_angle=args.axis_angle,
|
174
|
+
axis_sampling=args.axis_sampling,
|
175
|
+
n_symmetry=args.axis_symmetry,
|
176
|
+
)
|
177
|
+
return rotations
|
178
|
+
|
179
|
+
|
180
|
+
# TODO: Think about whether wedge mask should also be added to target
|
181
|
+
# For now leave it at the cost of incorrect upper bound on the scores
|
182
|
+
def setup_filter(args, template: Density, target: Density) -> Tuple[Compose, Compose]:
|
183
|
+
from tme.preprocessing import LinearWhiteningFilter, BandPassFilter
|
184
|
+
from tme.preprocessing.tilt_series import (
|
185
|
+
Wedge,
|
186
|
+
WedgeReconstructed,
|
187
|
+
ReconstructFromTilt,
|
188
|
+
)
|
189
|
+
|
190
|
+
template_filter, target_filter = [], []
|
191
|
+
if args.tilt_angles is not None:
|
192
|
+
try:
|
193
|
+
wedge = Wedge.from_file(args.tilt_angles)
|
194
|
+
wedge.weight_type = args.tilt_weighting
|
195
|
+
if args.tilt_weighting in ("angle", None) and args.ctf_file is None:
|
196
|
+
wedge = WedgeReconstructed(
|
197
|
+
angles=wedge.angles, weight_wedge=args.tilt_weighting == "angle"
|
198
|
+
)
|
199
|
+
except FileNotFoundError:
|
200
|
+
tilt_step, create_continuous_wedge = None, True
|
201
|
+
tilt_start, tilt_stop = args.tilt_angles.split(",")
|
202
|
+
if ":" in tilt_stop:
|
203
|
+
create_continuous_wedge = False
|
204
|
+
tilt_stop, tilt_step = tilt_stop.split(":")
|
205
|
+
tilt_start, tilt_stop = float(tilt_start), float(tilt_stop)
|
206
|
+
tilt_angles = (tilt_start, tilt_stop)
|
207
|
+
if tilt_step is not None:
|
208
|
+
tilt_step = float(tilt_step)
|
209
|
+
tilt_angles = np.arange(
|
210
|
+
-tilt_start, tilt_stop + tilt_step, tilt_step
|
211
|
+
).tolist()
|
212
|
+
|
213
|
+
if args.tilt_weighting is not None and tilt_step is None:
|
214
|
+
raise ValueError(
|
215
|
+
"Tilt weighting is not supported for continuous wedges."
|
216
|
+
)
|
217
|
+
if args.tilt_weighting not in ("angle", None):
|
218
|
+
raise ValueError(
|
219
|
+
"Tilt weighting schemes other than 'angle' or 'None' require "
|
220
|
+
"a specification of electron doses via --tilt_angles."
|
221
|
+
)
|
222
|
+
|
223
|
+
wedge = Wedge(
|
224
|
+
angles=tilt_angles,
|
225
|
+
opening_axis=args.wedge_axes[0],
|
226
|
+
tilt_axis=args.wedge_axes[1],
|
227
|
+
shape=None,
|
228
|
+
weight_type=None,
|
229
|
+
weights=np.ones_like(tilt_angles),
|
230
|
+
)
|
231
|
+
if args.tilt_weighting in ("angle", None) and args.ctf_file is None:
|
232
|
+
wedge = WedgeReconstructed(
|
233
|
+
angles=tilt_angles,
|
234
|
+
weight_wedge=args.tilt_weighting == "angle",
|
235
|
+
create_continuous_wedge=create_continuous_wedge,
|
236
|
+
)
|
237
|
+
|
238
|
+
wedge.opening_axis = args.wedge_axes[0]
|
239
|
+
wedge.tilt_axis = args.wedge_axes[1]
|
240
|
+
wedge.sampling_rate = template.sampling_rate
|
241
|
+
template_filter.append(wedge)
|
242
|
+
if not isinstance(wedge, WedgeReconstructed):
|
243
|
+
template_filter.append(
|
244
|
+
ReconstructFromTilt(
|
245
|
+
reconstruction_filter=args.reconstruction_filter,
|
246
|
+
interpolation_order=args.reconstruction_interpolation_order,
|
247
|
+
)
|
248
|
+
)
|
249
|
+
|
250
|
+
if args.ctf_file is not None or args.defocus is not None:
|
251
|
+
from tme.preprocessing.tilt_series import CTF
|
252
|
+
|
253
|
+
needs_reconstruction = True
|
254
|
+
if args.ctf_file is not None:
|
255
|
+
ctf = CTF.from_file(args.ctf_file)
|
256
|
+
n_tilts_ctfs, n_tils_angles = len(ctf.defocus_x), len(wedge.angles)
|
257
|
+
if n_tilts_ctfs != n_tils_angles:
|
258
|
+
raise ValueError(
|
259
|
+
f"CTF file contains {n_tilts_ctfs} micrographs, but match_template "
|
260
|
+
f"recieved {n_tils_angles} tilt angles. Expected one angle "
|
261
|
+
"per micrograph."
|
262
|
+
)
|
263
|
+
ctf.angles = wedge.angles
|
264
|
+
ctf.opening_axis, ctf.tilt_axis = args.wedge_axes
|
265
|
+
else:
|
266
|
+
needs_reconstruction = False
|
267
|
+
ctf = CTF(
|
268
|
+
defocus_x=[args.defocus],
|
269
|
+
phase_shift=[args.phase_shift],
|
270
|
+
defocus_y=None,
|
271
|
+
angles=[0],
|
272
|
+
shape=None,
|
273
|
+
return_real_fourier=True,
|
274
|
+
)
|
275
|
+
ctf.sampling_rate = template.sampling_rate
|
276
|
+
ctf.flip_phase = not args.no_flip_phase
|
277
|
+
ctf.amplitude_contrast = args.amplitude_contrast
|
278
|
+
ctf.spherical_aberration = args.spherical_aberration
|
279
|
+
ctf.acceleration_voltage = args.acceleration_voltage * 1e3
|
280
|
+
ctf.correct_defocus_gradient = args.correct_defocus_gradient
|
281
|
+
|
282
|
+
if not needs_reconstruction:
|
283
|
+
template_filter.append(ctf)
|
284
|
+
elif isinstance(template_filter[-1], ReconstructFromTilt):
|
285
|
+
template_filter.insert(-1, ctf)
|
286
|
+
else:
|
287
|
+
template_filter.insert(0, ctf)
|
288
|
+
template_filter.insert(
|
289
|
+
1,
|
290
|
+
ReconstructFromTilt(
|
291
|
+
reconstruction_filter=args.reconstruction_filter,
|
292
|
+
interpolation_order=args.reconstruction_interpolation_order,
|
293
|
+
),
|
294
|
+
)
|
295
|
+
|
296
|
+
if args.lowpass or args.highpass is not None:
|
297
|
+
lowpass, highpass = args.lowpass, args.highpass
|
298
|
+
if args.pass_format == "voxel":
|
299
|
+
if lowpass is not None:
|
300
|
+
lowpass = np.max(np.multiply(lowpass, template.sampling_rate))
|
301
|
+
if highpass is not None:
|
302
|
+
highpass = np.max(np.multiply(highpass, template.sampling_rate))
|
303
|
+
elif args.pass_format == "frequency":
|
304
|
+
if lowpass is not None:
|
305
|
+
lowpass = np.max(np.divide(template.sampling_rate, lowpass))
|
306
|
+
if highpass is not None:
|
307
|
+
highpass = np.max(np.divide(template.sampling_rate, highpass))
|
308
|
+
|
309
|
+
bandpass = BandPassFilter(
|
310
|
+
use_gaussian=args.no_pass_smooth,
|
311
|
+
lowpass=lowpass,
|
312
|
+
highpass=highpass,
|
313
|
+
sampling_rate=template.sampling_rate,
|
314
|
+
)
|
315
|
+
template_filter.append(bandpass)
|
316
|
+
target_filter.append(bandpass)
|
317
|
+
|
318
|
+
if args.whiten_spectrum:
|
319
|
+
whitening_filter = LinearWhiteningFilter()
|
320
|
+
template_filter.append(whitening_filter)
|
321
|
+
target_filter.append(whitening_filter)
|
322
|
+
|
323
|
+
needs_reconstruction = any(
|
324
|
+
[isinstance(t, ReconstructFromTilt) for t in template_filter]
|
325
|
+
)
|
326
|
+
if needs_reconstruction and args.reconstruction_filter is None:
|
327
|
+
warnings.warn(
|
328
|
+
"Consider using a --reconstruction_filter such as 'ramp' to avoid artifacts."
|
329
|
+
)
|
330
|
+
|
331
|
+
template_filter = Compose(template_filter) if len(template_filter) else None
|
332
|
+
target_filter = Compose(target_filter) if len(target_filter) else None
|
333
|
+
|
334
|
+
return template_filter, target_filter
|
335
|
+
|
336
|
+
|
155
337
|
def parse_args():
|
156
338
|
parser = argparse.ArgumentParser(description="Perform template matching.")
|
157
339
|
|
@@ -224,13 +406,71 @@ def parse_args():
|
|
224
406
|
help="Template matching scoring function.",
|
225
407
|
)
|
226
408
|
scoring_group.add_argument(
|
409
|
+
"-p",
|
410
|
+
dest="peak_calling",
|
411
|
+
action="store_true",
|
412
|
+
default=False,
|
413
|
+
help="Perform peak calling instead of score aggregation.",
|
414
|
+
)
|
415
|
+
|
416
|
+
angular_group = parser.add_argument_group("Angular Sampling")
|
417
|
+
angular_exclusive = angular_group.add_mutually_exclusive_group(required=True)
|
418
|
+
|
419
|
+
angular_exclusive.add_argument(
|
227
420
|
"-a",
|
228
421
|
dest="angular_sampling",
|
229
422
|
type=check_positive,
|
230
|
-
default=
|
231
|
-
help="Angular sampling rate
|
423
|
+
default=None,
|
424
|
+
help="Angular sampling rate using optimized rotational sets."
|
232
425
|
"A lower number yields more rotations. Values >= 180 sample only the identity.",
|
233
426
|
)
|
427
|
+
angular_exclusive.add_argument(
|
428
|
+
"--cone_angle",
|
429
|
+
dest="cone_angle",
|
430
|
+
type=check_positive,
|
431
|
+
default=None,
|
432
|
+
help="Half-angle of the cone to be sampled in degrees. Allows to sample a "
|
433
|
+
"narrow interval around a known orientation, e.g. for surface oversampling.",
|
434
|
+
)
|
435
|
+
angular_group.add_argument(
|
436
|
+
"--cone_sampling",
|
437
|
+
dest="cone_sampling",
|
438
|
+
type=check_positive,
|
439
|
+
default=None,
|
440
|
+
help="Sampling rate of the cone in degrees.",
|
441
|
+
)
|
442
|
+
angular_group.add_argument(
|
443
|
+
"--axis_angle",
|
444
|
+
dest="axis_angle",
|
445
|
+
type=check_positive,
|
446
|
+
default=360.0,
|
447
|
+
required=False,
|
448
|
+
help="Sampling angle along the z-axis of the cone. Defaults to 360.",
|
449
|
+
)
|
450
|
+
angular_group.add_argument(
|
451
|
+
"--axis_sampling",
|
452
|
+
dest="axis_sampling",
|
453
|
+
type=check_positive,
|
454
|
+
default=None,
|
455
|
+
required=False,
|
456
|
+
help="Sampling rate along the z-axis of the cone. Defaults to --cone_sampling.",
|
457
|
+
)
|
458
|
+
angular_group.add_argument(
|
459
|
+
"--axis_symmetry",
|
460
|
+
dest="axis_symmetry",
|
461
|
+
type=check_positive,
|
462
|
+
default=1,
|
463
|
+
required=False,
|
464
|
+
help="N-fold symmetry around z-axis of the cone.",
|
465
|
+
)
|
466
|
+
angular_group.add_argument(
|
467
|
+
"--no_use_optimized_set",
|
468
|
+
dest="no_use_optimized_set",
|
469
|
+
action="store_true",
|
470
|
+
default=False,
|
471
|
+
required=False,
|
472
|
+
help="Whether to use random uniform instead of optimized rotation sets.",
|
473
|
+
)
|
234
474
|
|
235
475
|
computation_group = parser.add_argument_group("Computation")
|
236
476
|
computation_group.add_argument(
|
@@ -276,21 +516,6 @@ def parse_args():
|
|
276
516
|
help="Fraction of available memory that can be used. Defaults to 0.85 and is "
|
277
517
|
"ignored if --ram is set",
|
278
518
|
)
|
279
|
-
computation_group.add_argument(
|
280
|
-
"--use_mixed_precision",
|
281
|
-
dest="use_mixed_precision",
|
282
|
-
action="store_true",
|
283
|
-
default=False,
|
284
|
-
help="Use float16 for real values operations where possible.",
|
285
|
-
)
|
286
|
-
computation_group.add_argument(
|
287
|
-
"--use_memmap",
|
288
|
-
dest="use_memmap",
|
289
|
-
action="store_true",
|
290
|
-
default=False,
|
291
|
-
help="Use memmaps to offload large data objects to disk. "
|
292
|
-
"Particularly useful for large inputs in combination with --use_gpu.",
|
293
|
-
)
|
294
519
|
computation_group.add_argument(
|
295
520
|
"--temp_directory",
|
296
521
|
dest="temp_directory",
|
@@ -315,11 +540,27 @@ def parse_args():
|
|
315
540
|
help="Resolution to highpass filter template and target to in the same unit "
|
316
541
|
"as the sampling rate of template and target (typically Ångstrom).",
|
317
542
|
)
|
543
|
+
filter_group.add_argument(
|
544
|
+
"--no_pass_smooth",
|
545
|
+
dest="no_pass_smooth",
|
546
|
+
action="store_false",
|
547
|
+
default=True,
|
548
|
+
help="Whether a hard edge filter should be used for --lowpass and --highpass.",
|
549
|
+
)
|
550
|
+
filter_group.add_argument(
|
551
|
+
"--pass_format",
|
552
|
+
dest="pass_format",
|
553
|
+
type=str,
|
554
|
+
required=False,
|
555
|
+
choices=["sampling_rate", "voxel", "frequency"],
|
556
|
+
help="How values passed to --lowpass and --highpass should be interpreted. "
|
557
|
+
"By default, they are assumed to be in units of sampling rate, e.g. Ångstrom.",
|
558
|
+
)
|
318
559
|
filter_group.add_argument(
|
319
560
|
"--whiten_spectrum",
|
320
561
|
dest="whiten_spectrum",
|
321
562
|
action="store_true",
|
322
|
-
default=
|
563
|
+
default=None,
|
323
564
|
help="Apply spectral whitening to template and target based on target spectrum.",
|
324
565
|
)
|
325
566
|
filter_group.add_argument(
|
@@ -327,7 +568,7 @@ def parse_args():
|
|
327
568
|
dest="wedge_axes",
|
328
569
|
type=str,
|
329
570
|
required=False,
|
330
|
-
default=
|
571
|
+
default=None,
|
331
572
|
help="Indices of wedge opening and tilt axis, e.g. 0,2 for a wedge that is open "
|
332
573
|
"in z-direction and tilted over the x axis.",
|
333
574
|
)
|
@@ -337,10 +578,10 @@ def parse_args():
|
|
337
578
|
type=str,
|
338
579
|
required=False,
|
339
580
|
default=None,
|
340
|
-
help="Path to a file
|
341
|
-
"start and stop stage tilt angle, e.g. 50,45, which
|
342
|
-
"mask. Alternatively, a tilt step size can be
|
343
|
-
"sample 5.0 degree tilt angle steps.",
|
581
|
+
help="Path to a tab-separated file containing the column angles and optionally "
|
582
|
+
" weights, or comma separated start and stop stage tilt angle, e.g. 50,45, which "
|
583
|
+
" yields a continuous wedge mask. Alternatively, a tilt step size can be "
|
584
|
+
"specified like 50,45:5.0 to sample 5.0 degree tilt angle steps.",
|
344
585
|
)
|
345
586
|
filter_group.add_argument(
|
346
587
|
"--tilt_weighting",
|
@@ -351,17 +592,93 @@ def parse_args():
|
|
351
592
|
default=None,
|
352
593
|
help="Weighting scheme used to reweight individual tilts. Available options: "
|
353
594
|
"angle (cosine based weighting), "
|
354
|
-
"relion (relion formalism for wedge weighting
|
595
|
+
"relion (relion formalism for wedge weighting) requires,"
|
355
596
|
"grigorieff (exposure filter as defined in Grant and Grigorieff 2015)."
|
356
|
-
"",
|
597
|
+
"relion and grigorieff require electron doses in --tilt_angles weights column.",
|
357
598
|
)
|
358
599
|
filter_group.add_argument(
|
600
|
+
"--reconstruction_filter",
|
601
|
+
dest="reconstruction_filter",
|
602
|
+
type=str,
|
603
|
+
required=False,
|
604
|
+
choices=["ram-lak", "ramp", "ramp-cont", "shepp-logan", "cosine", "hamming"],
|
605
|
+
default=None,
|
606
|
+
help="Filter applied when reconstructing (N+1)-D from N-D filters.",
|
607
|
+
)
|
608
|
+
filter_group.add_argument(
|
609
|
+
"--reconstruction_interpolation_order",
|
610
|
+
dest="reconstruction_interpolation_order",
|
611
|
+
type=int,
|
612
|
+
default=1,
|
613
|
+
required=False,
|
614
|
+
help="Analogous to --interpolation_order but for reconstruction.",
|
615
|
+
)
|
616
|
+
|
617
|
+
ctf_group = parser.add_argument_group("Contrast Transfer Function")
|
618
|
+
ctf_group.add_argument(
|
359
619
|
"--ctf_file",
|
360
620
|
dest="ctf_file",
|
361
621
|
type=str,
|
362
622
|
required=False,
|
363
623
|
default=None,
|
364
|
-
help="Path to a file with CTF parameters."
|
624
|
+
help="Path to a file with CTF parameters from CTFFIND4. Each line will be "
|
625
|
+
"interpreted as tilt obtained at the angle specified in --tilt_angles. ",
|
626
|
+
)
|
627
|
+
ctf_group.add_argument(
|
628
|
+
"--defocus",
|
629
|
+
dest="defocus",
|
630
|
+
type=float,
|
631
|
+
required=False,
|
632
|
+
default=None,
|
633
|
+
help="Defocus in units of sampling rate (typically Ångstrom). "
|
634
|
+
"Superseded by --ctf_file.",
|
635
|
+
)
|
636
|
+
ctf_group.add_argument(
|
637
|
+
"--phase_shift",
|
638
|
+
dest="phase_shift",
|
639
|
+
type=float,
|
640
|
+
required=False,
|
641
|
+
default=0,
|
642
|
+
help="Phase shift in degrees. Superseded by --ctf_file.",
|
643
|
+
)
|
644
|
+
ctf_group.add_argument(
|
645
|
+
"--acceleration_voltage",
|
646
|
+
dest="acceleration_voltage",
|
647
|
+
type=float,
|
648
|
+
required=False,
|
649
|
+
default=300,
|
650
|
+
help="Acceleration voltage in kV, defaults to 300.",
|
651
|
+
)
|
652
|
+
ctf_group.add_argument(
|
653
|
+
"--spherical_aberration",
|
654
|
+
dest="spherical_aberration",
|
655
|
+
type=float,
|
656
|
+
required=False,
|
657
|
+
default=2.7e7,
|
658
|
+
help="Spherical aberration in units of sampling rate (typically Ångstrom).",
|
659
|
+
)
|
660
|
+
ctf_group.add_argument(
|
661
|
+
"--amplitude_contrast",
|
662
|
+
dest="amplitude_contrast",
|
663
|
+
type=float,
|
664
|
+
required=False,
|
665
|
+
default=0.07,
|
666
|
+
help="Amplitude contrast, defaults to 0.07.",
|
667
|
+
)
|
668
|
+
ctf_group.add_argument(
|
669
|
+
"--no_flip_phase",
|
670
|
+
dest="no_flip_phase",
|
671
|
+
action="store_false",
|
672
|
+
required=False,
|
673
|
+
help="Whether the phase of the computed CTF should not be flipped.",
|
674
|
+
)
|
675
|
+
ctf_group.add_argument(
|
676
|
+
"--correct_defocus_gradient",
|
677
|
+
dest="correct_defocus_gradient",
|
678
|
+
action="store_true",
|
679
|
+
required=False,
|
680
|
+
help="[Experimental] Whether to compute a more accurate 3D CTF incorporating "
|
681
|
+
"defocus gradients.",
|
365
682
|
)
|
366
683
|
|
367
684
|
performance_group = parser.add_argument_group("Performance")
|
@@ -413,6 +730,21 @@ def parse_args():
|
|
413
730
|
help="Spline interpolation used for template rotations. If less than zero "
|
414
731
|
"no interpolation is performed.",
|
415
732
|
)
|
733
|
+
performance_group.add_argument(
|
734
|
+
"--use_mixed_precision",
|
735
|
+
dest="use_mixed_precision",
|
736
|
+
action="store_true",
|
737
|
+
default=False,
|
738
|
+
help="Use float16 for real values operations where possible.",
|
739
|
+
)
|
740
|
+
performance_group.add_argument(
|
741
|
+
"--use_memmap",
|
742
|
+
dest="use_memmap",
|
743
|
+
action="store_true",
|
744
|
+
default=False,
|
745
|
+
help="Use memmaps to offload large data objects to disk. "
|
746
|
+
"Particularly useful for large inputs in combination with --use_gpu.",
|
747
|
+
)
|
416
748
|
|
417
749
|
analyzer_group = parser.add_argument_group("Analyzer")
|
418
750
|
analyzer_group.add_argument(
|
@@ -423,14 +755,9 @@ def parse_args():
|
|
423
755
|
default=0,
|
424
756
|
help="Minimum template matching scores to consider for analysis.",
|
425
757
|
)
|
426
|
-
|
427
|
-
"-p",
|
428
|
-
dest="peak_calling",
|
429
|
-
action="store_true",
|
430
|
-
default=False,
|
431
|
-
help="Perform peak calling instead of score aggregation.",
|
432
|
-
)
|
758
|
+
|
433
759
|
args = parser.parse_args()
|
760
|
+
args.version = __version__
|
434
761
|
|
435
762
|
if args.interpolation_order < 0:
|
436
763
|
args.interpolation_order = None
|
@@ -467,94 +794,22 @@ def parse_args():
|
|
467
794
|
int(x) for x in os.environ["CUDA_VISIBLE_DEVICES"].split(",")
|
468
795
|
]
|
469
796
|
|
470
|
-
if args.wedge_axes is not None:
|
471
|
-
args.wedge_axes = [int(x) for x in args.wedge_axes.split(",")]
|
472
|
-
|
473
|
-
if args.tilt_angles is not None and args.wedge_axes is None:
|
474
|
-
raise ValueError("Wedge axes have to be specified with tilt angles.")
|
475
|
-
|
476
|
-
if args.ctf_file is not None and args.wedge_axes is None:
|
477
|
-
raise ValueError("Wedge axes have to be specified with CTF parameters.")
|
478
|
-
if args.ctf_file is not None and args.tilt_angles is None:
|
479
|
-
raise ValueError("Angles have to be specified with CTF parameters.")
|
480
|
-
|
481
|
-
return args
|
482
|
-
|
483
|
-
|
484
|
-
def setup_filter(args, template: Density, target: Density) -> Tuple[Compose, Compose]:
|
485
|
-
from tme.preprocessing import LinearWhiteningFilter, BandPassFilter
|
486
|
-
|
487
|
-
template_filter, target_filter = [], []
|
488
797
|
if args.tilt_angles is not None:
|
489
|
-
|
490
|
-
|
491
|
-
|
492
|
-
|
493
|
-
|
494
|
-
|
495
|
-
|
496
|
-
wedge = Wedge.from_file(args.tilt_angles)
|
497
|
-
wedge.weight_type = args.tilt_weighting
|
498
|
-
except FileNotFoundError:
|
499
|
-
tilt_step = None
|
500
|
-
tilt_start, tilt_stop = args.tilt_angles.split(",")
|
501
|
-
if ":" in tilt_stop:
|
502
|
-
tilt_stop, tilt_step = tilt_stop.split(":")
|
503
|
-
tilt_start, tilt_stop = float(tilt_start), float(tilt_stop)
|
504
|
-
tilt_angles = None
|
505
|
-
if tilt_step is not None:
|
506
|
-
tilt_step = float(tilt_step)
|
507
|
-
tilt_angles = np.arange(
|
508
|
-
-tilt_start, tilt_stop + tilt_step, tilt_step
|
509
|
-
).tolist()
|
510
|
-
wedge = WedgeReconstructed(
|
511
|
-
angles=tilt_angles,
|
512
|
-
start_tilt=tilt_start,
|
513
|
-
stop_tilt=tilt_stop,
|
514
|
-
)
|
515
|
-
wedge.opening_axis = args.wedge_axes[0]
|
516
|
-
wedge.tilt_axis = args.wedge_axes[1]
|
517
|
-
wedge.sampling_rate = template.sampling_rate
|
518
|
-
template_filter.append(wedge)
|
519
|
-
if not isinstance(wedge, WedgeReconstructed):
|
520
|
-
template_filter.append(ReconstructFromTilt())
|
798
|
+
if args.wedge_axes is None:
|
799
|
+
raise ValueError("Need to specify --wedge_axes when --tilt_angles is set.")
|
800
|
+
if not exists(args.tilt_angles):
|
801
|
+
try:
|
802
|
+
float(args.tilt_angles.split(",")[0])
|
803
|
+
except ValueError:
|
804
|
+
raise ValueError(f"{args.tilt_angles} is not a file nor a range.")
|
521
805
|
|
522
|
-
if args.ctf_file is not None:
|
523
|
-
|
524
|
-
|
525
|
-
ctf = CTF.from_file(args.ctf_file)
|
526
|
-
ctf.tilt_axis = args.wedge_axes[1]
|
527
|
-
ctf.opening_axis = args.wedge_axes[0]
|
528
|
-
template_filter.append(ctf)
|
529
|
-
if isinstance(template_filter[-1], ReconstructFromTilt):
|
530
|
-
template_filter.insert(-1, ctf)
|
531
|
-
else:
|
532
|
-
template_filter.insert(0, ctf)
|
533
|
-
template_filter.isnert(1, ReconstructFromTilt())
|
534
|
-
|
535
|
-
if args.lowpass or args.highpass is not None:
|
536
|
-
from tme.preprocessing import BandPassFilter
|
537
|
-
|
538
|
-
bandpass = BandPassFilter(
|
539
|
-
use_gaussian=True,
|
540
|
-
lowpass=args.lowpass,
|
541
|
-
highpass=args.highpass,
|
542
|
-
sampling_rate=template.sampling_rate,
|
543
|
-
)
|
544
|
-
template_filter.append(bandpass)
|
545
|
-
target_filter.append(bandpass)
|
546
|
-
|
547
|
-
if args.whiten_spectrum:
|
548
|
-
from tme.preprocessing import LinearWhiteningFilter
|
549
|
-
|
550
|
-
whitening_filter = LinearWhiteningFilter()
|
551
|
-
template_filter.append(whitening_filter)
|
552
|
-
target_filter.append(whitening_filter)
|
806
|
+
if args.ctf_file is not None and args.tilt_angles is None:
|
807
|
+
raise ValueError("Need to specify --tilt_angles when --ctf_file is set.")
|
553
808
|
|
554
|
-
|
555
|
-
|
809
|
+
if args.wedge_axes is not None:
|
810
|
+
args.wedge_axes = tuple(int(i) for i in args.wedge_axes.split(","))
|
556
811
|
|
557
|
-
return
|
812
|
+
return args
|
558
813
|
|
559
814
|
|
560
815
|
def main():
|
@@ -566,17 +821,20 @@ def main():
|
|
566
821
|
try:
|
567
822
|
template = Density.from_file(args.template)
|
568
823
|
except Exception:
|
824
|
+
drop = target.metadata.get("batch_dimension", ())
|
825
|
+
keep = [i not in drop for i in range(target.data.ndim)]
|
569
826
|
template = Density.from_structure(
|
570
827
|
filename_or_structure=args.template,
|
571
|
-
sampling_rate=target.sampling_rate,
|
828
|
+
sampling_rate=target.sampling_rate[keep],
|
572
829
|
)
|
573
830
|
|
574
|
-
if
|
575
|
-
|
576
|
-
|
577
|
-
|
578
|
-
|
579
|
-
|
831
|
+
if target.sampling_rate.size == template.sampling_rate.size:
|
832
|
+
if not np.allclose(target.sampling_rate, template.sampling_rate):
|
833
|
+
print(
|
834
|
+
f"Resampling template to {target.sampling_rate}. "
|
835
|
+
"Consider providing a template with the same sampling rate as the target."
|
836
|
+
)
|
837
|
+
template = template.resample(target.sampling_rate, order=3)
|
580
838
|
|
581
839
|
template_mask = load_and_validate_mask(
|
582
840
|
mask_target=template, mask_path=args.template_mask
|
@@ -709,31 +967,52 @@ def main():
|
|
709
967
|
if args.memory is None:
|
710
968
|
args.memory = int(args.memory_scaling * available_memory)
|
711
969
|
|
712
|
-
|
713
|
-
if args.
|
714
|
-
|
970
|
+
callback_class = MaxScoreOverRotations
|
971
|
+
if args.peak_calling:
|
972
|
+
callback_class = PeakCallerMaximumFilter
|
973
|
+
|
974
|
+
matching_data = MatchingData(
|
975
|
+
target=target,
|
976
|
+
template=template.data,
|
977
|
+
target_mask=target_mask,
|
978
|
+
template_mask=template_mask,
|
979
|
+
invert_target=args.invert_target_contrast,
|
980
|
+
rotations=parse_rotation_logic(args=args, ndim=template.data.ndim),
|
981
|
+
)
|
715
982
|
|
716
|
-
|
983
|
+
template_filter, target_filter = setup_filter(args, template, target)
|
984
|
+
matching_data.template_filter = template_filter
|
985
|
+
matching_data.target_filter = target_filter
|
986
|
+
|
987
|
+
target_dims = target.metadata.get("batch_dimension", None)
|
988
|
+
matching_data._set_batch_dimension(target_dims=target_dims, template_dims=None)
|
989
|
+
args.score = "FLC2" if target_dims is not None else args.score
|
990
|
+
args.target_batch, args.template_batch = target_dims, None
|
991
|
+
|
992
|
+
template_box = matching_data._output_template_shape
|
717
993
|
if not args.pad_fourier:
|
718
994
|
template_box = np.ones(len(template_box), dtype=int)
|
719
995
|
|
720
|
-
|
721
|
-
|
722
|
-
|
996
|
+
target_padding = np.zeros(
|
997
|
+
(backend.size(matching_data._output_template_shape)), dtype=int
|
998
|
+
)
|
999
|
+
if args.pad_target_edges:
|
1000
|
+
target_padding = matching_data._output_template_shape
|
723
1001
|
|
724
1002
|
splits, schedule = compute_parallelization_schedule(
|
725
1003
|
shape1=target.shape,
|
726
|
-
shape2=template_box,
|
727
|
-
shape1_padding=target_padding,
|
1004
|
+
shape2=tuple(int(x) for x in template_box),
|
1005
|
+
shape1_padding=tuple(int(x) for x in target_padding),
|
728
1006
|
max_cores=args.cores,
|
729
1007
|
max_ram=args.memory,
|
730
1008
|
split_only_outer=args.use_gpu,
|
731
1009
|
matching_method=args.score,
|
732
1010
|
analyzer_method=callback_class.__name__,
|
733
1011
|
backend=backend._backend_name,
|
734
|
-
float_nbytes=backend.datatype_bytes(backend.
|
1012
|
+
float_nbytes=backend.datatype_bytes(backend._float_dtype),
|
735
1013
|
complex_nbytes=backend.datatype_bytes(backend._complex_dtype),
|
736
|
-
integer_nbytes=backend.datatype_bytes(backend.
|
1014
|
+
integer_nbytes=backend.datatype_bytes(backend._int_dtype),
|
1015
|
+
split_axes=target_dims,
|
737
1016
|
)
|
738
1017
|
|
739
1018
|
if splits is None:
|
@@ -743,32 +1022,7 @@ def main():
|
|
743
1022
|
)
|
744
1023
|
exit(-1)
|
745
1024
|
|
746
|
-
analyzer_args = {
|
747
|
-
"score_threshold": args.score_threshold,
|
748
|
-
"number_of_peaks": 1000,
|
749
|
-
"convolution_mode": "valid",
|
750
|
-
"use_memmap": args.use_memmap,
|
751
|
-
}
|
752
|
-
|
753
1025
|
matching_setup, matching_score = MATCHING_EXHAUSTIVE_REGISTER[args.score]
|
754
|
-
matching_data = MatchingData(target=target, template=template.data)
|
755
|
-
matching_data.rotations = get_rotation_matrices(
|
756
|
-
angular_sampling=args.angular_sampling, dim=target.data.ndim
|
757
|
-
)
|
758
|
-
if args.angular_sampling >= 180:
|
759
|
-
ndim = target.data.ndim
|
760
|
-
matching_data.rotations = np.eye(ndim).reshape(1, ndim, ndim)
|
761
|
-
|
762
|
-
template_filter, target_filter = setup_filter(args, template, target)
|
763
|
-
matching_data.template_filter = template_filter
|
764
|
-
matching_data.target_filter = target_filter
|
765
|
-
|
766
|
-
matching_data._invert_target = args.invert_target_contrast
|
767
|
-
if target_mask is not None:
|
768
|
-
matching_data.target_mask = target_mask
|
769
|
-
if template_mask is not None:
|
770
|
-
matching_data.template_mask = template_mask.data
|
771
|
-
|
772
1026
|
n_splits = np.prod(list(splits.values()))
|
773
1027
|
target_split = ", ".join(
|
774
1028
|
[":".join([str(x) for x in axis]) for axis in splits.items()]
|
@@ -798,10 +1052,45 @@ def main():
|
|
798
1052
|
label_width=max(len(key) for key in options.keys()) + 2,
|
799
1053
|
)
|
800
1054
|
|
801
|
-
|
1055
|
+
filter_args = {
|
1056
|
+
"Lowpass": args.lowpass,
|
1057
|
+
"Highpass": args.highpass,
|
1058
|
+
"Smooth Pass": args.no_pass_smooth,
|
1059
|
+
"Pass Format": args.pass_format,
|
1060
|
+
"Spectral Whitening": args.whiten_spectrum,
|
1061
|
+
"Wedge Axes": args.wedge_axes,
|
1062
|
+
"Tilt Angles": args.tilt_angles,
|
1063
|
+
"Tilt Weighting": args.tilt_weighting,
|
1064
|
+
"Reconstruction Filter": args.reconstruction_filter,
|
1065
|
+
}
|
1066
|
+
if args.ctf_file is not None or args.defocus is not None:
|
1067
|
+
filter_args["CTF File"] = args.ctf_file
|
1068
|
+
filter_args["Defocus"] = args.defocus
|
1069
|
+
filter_args["Phase Shift"] = args.phase_shift
|
1070
|
+
filter_args["No Flip Phase"] = args.no_flip_phase
|
1071
|
+
filter_args["Acceleration Voltage"] = args.acceleration_voltage
|
1072
|
+
filter_args["Spherical Aberration"] = args.spherical_aberration
|
1073
|
+
filter_args["Amplitude Contrast"] = args.amplitude_contrast
|
1074
|
+
filter_args["Correct Defocus"] = args.correct_defocus_gradient
|
1075
|
+
|
1076
|
+
filter_args = {k: v for k, v in filter_args.items() if v is not None}
|
1077
|
+
if len(filter_args):
|
1078
|
+
print_block(
|
1079
|
+
name="Filters",
|
1080
|
+
data=filter_args,
|
1081
|
+
label_width=max(len(key) for key in options.keys()) + 2,
|
1082
|
+
)
|
1083
|
+
|
1084
|
+
analyzer_args = {
|
1085
|
+
"score_threshold": args.score_threshold,
|
1086
|
+
"number_of_peaks": 1000,
|
1087
|
+
"convolution_mode": "valid",
|
1088
|
+
"use_memmap": args.use_memmap,
|
1089
|
+
}
|
1090
|
+
analyzer_args = {"Analyzer": callback_class, **analyzer_args}
|
802
1091
|
print_block(
|
803
1092
|
name="Score Analysis Options",
|
804
|
-
data=
|
1093
|
+
data=analyzer_args,
|
805
1094
|
label_width=max(len(key) for key in options.keys()) + 2,
|
806
1095
|
)
|
807
1096
|
print("\n" + "-" * 80)
|
@@ -832,16 +1121,16 @@ def main():
|
|
832
1121
|
candidates[0] *= target_mask.data
|
833
1122
|
with warnings.catch_warnings():
|
834
1123
|
warnings.simplefilter("ignore", category=UserWarning)
|
1124
|
+
nbytes = backend.datatype_bytes(backend._float_dtype)
|
1125
|
+
dtype = np.float32 if nbytes == 4 else np.float16
|
1126
|
+
rot_dim = matching_data.rotations.shape[1]
|
835
1127
|
candidates[3] = {
|
836
1128
|
x: euler_from_rotationmatrix(
|
837
|
-
np.frombuffer(i, dtype=
|
838
|
-
candidates[0].ndim, candidates[0].ndim
|
839
|
-
)
|
1129
|
+
np.frombuffer(i, dtype=dtype).reshape(rot_dim, rot_dim)
|
840
1130
|
)
|
841
1131
|
for i, x in candidates[3].items()
|
842
1132
|
}
|
843
|
-
|
844
|
-
candidates.append((target.origin, template.origin, target.sampling_rate, args))
|
1133
|
+
candidates.append((target.origin, template.origin, template.sampling_rate, args))
|
845
1134
|
write_pickle(data=candidates, filename=args.output)
|
846
1135
|
|
847
1136
|
runtime = time() - start
|