pytme 0.2.0b0__cp311-cp311-macosx_14_0_arm64.whl → 0.2.2__cp311-cp311-macosx_14_0_arm64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- pytme-0.2.2.data/scripts/match_template.py +1187 -0
- {pytme-0.2.0b0.data → pytme-0.2.2.data}/scripts/postprocess.py +170 -71
- {pytme-0.2.0b0.data → pytme-0.2.2.data}/scripts/preprocessor_gui.py +179 -86
- pytme-0.2.2.dist-info/METADATA +91 -0
- pytme-0.2.2.dist-info/RECORD +74 -0
- {pytme-0.2.0b0.dist-info → pytme-0.2.2.dist-info}/WHEEL +1 -1
- scripts/extract_candidates.py +126 -87
- scripts/match_template.py +596 -209
- scripts/match_template_filters.py +571 -223
- scripts/postprocess.py +170 -71
- scripts/preprocessor_gui.py +179 -86
- scripts/refine_matches.py +567 -159
- tme/__init__.py +0 -1
- tme/__version__.py +1 -1
- tme/analyzer.py +627 -855
- tme/backends/__init__.py +41 -11
- tme/backends/_jax_utils.py +185 -0
- tme/backends/cupy_backend.py +120 -225
- tme/backends/jax_backend.py +282 -0
- tme/backends/matching_backend.py +464 -388
- tme/backends/mlx_backend.py +45 -68
- tme/backends/npfftw_backend.py +256 -514
- tme/backends/pytorch_backend.py +41 -154
- tme/density.py +312 -421
- tme/extensions.cpython-311-darwin.so +0 -0
- tme/matching_data.py +366 -303
- tme/matching_exhaustive.py +279 -1521
- tme/matching_optimization.py +234 -129
- tme/matching_scores.py +884 -0
- tme/matching_utils.py +281 -387
- tme/memory.py +377 -0
- tme/orientations.py +226 -66
- tme/parser.py +3 -4
- tme/preprocessing/__init__.py +2 -0
- tme/preprocessing/_utils.py +217 -0
- tme/preprocessing/composable_filter.py +31 -0
- tme/preprocessing/compose.py +55 -0
- tme/preprocessing/frequency_filters.py +388 -0
- tme/preprocessing/tilt_series.py +1011 -0
- tme/preprocessor.py +574 -530
- tme/structure.py +495 -189
- tme/types.py +5 -3
- pytme-0.2.0b0.data/scripts/match_template.py +0 -800
- pytme-0.2.0b0.dist-info/METADATA +0 -73
- pytme-0.2.0b0.dist-info/RECORD +0 -66
- tme/helpers.py +0 -881
- tme/matching_constrained.py +0 -195
- {pytme-0.2.0b0.data → pytme-0.2.2.data}/scripts/estimate_ram_usage.py +0 -0
- {pytme-0.2.0b0.data → pytme-0.2.2.data}/scripts/preprocess.py +0 -0
- {pytme-0.2.0b0.dist-info → pytme-0.2.2.dist-info}/LICENSE +0 -0
- {pytme-0.2.0b0.dist-info → pytme-0.2.2.dist-info}/entry_points.txt +0 -0
- {pytme-0.2.0b0.dist-info → pytme-0.2.2.dist-info}/top_level.txt +0 -0
scripts/match_template.py
CHANGED
@@ -8,19 +8,19 @@
|
|
8
8
|
import os
|
9
9
|
import argparse
|
10
10
|
import warnings
|
11
|
-
import importlib.util
|
12
11
|
from sys import exit
|
13
12
|
from time import time
|
13
|
+
from typing import Tuple
|
14
14
|
from copy import deepcopy
|
15
|
-
from os.path import abspath
|
15
|
+
from os.path import abspath, exists
|
16
16
|
|
17
17
|
import numpy as np
|
18
18
|
|
19
|
-
from tme import Density,
|
19
|
+
from tme import Density, __version__
|
20
20
|
from tme.matching_utils import (
|
21
21
|
get_rotation_matrices,
|
22
|
+
get_rotations_around_vector,
|
22
23
|
compute_parallelization_schedule,
|
23
|
-
euler_from_rotationmatrix,
|
24
24
|
scramble_phases,
|
25
25
|
generate_tempfile_name,
|
26
26
|
write_pickle,
|
@@ -31,7 +31,8 @@ from tme.analyzer import (
|
|
31
31
|
MaxScoreOverRotations,
|
32
32
|
PeakCallerMaximumFilter,
|
33
33
|
)
|
34
|
-
from tme.backends import backend
|
34
|
+
from tme.backends import backend as be
|
35
|
+
from tme.preprocessing import Compose
|
35
36
|
|
36
37
|
|
37
38
|
def get_func_fullname(func) -> str:
|
@@ -49,7 +50,7 @@ def print_block(name: str, data: dict, label_width=20) -> None:
|
|
49
50
|
|
50
51
|
def print_entry() -> None:
|
51
52
|
width = 80
|
52
|
-
text = f"
|
53
|
+
text = f" pytme v{__version__} "
|
53
54
|
padding_total = width - len(text) - 2
|
54
55
|
padding_left = padding_total // 2
|
55
56
|
padding_right = padding_total - padding_left
|
@@ -150,8 +151,200 @@ def crop_data(data: Density, cutoff: float, data_mask: Density = None) -> bool:
|
|
150
151
|
return True
|
151
152
|
|
152
153
|
|
154
|
+
def parse_rotation_logic(args, ndim):
|
155
|
+
if args.angular_sampling is not None:
|
156
|
+
rotations = get_rotation_matrices(
|
157
|
+
angular_sampling=args.angular_sampling,
|
158
|
+
dim=ndim,
|
159
|
+
use_optimized_set=not args.no_use_optimized_set,
|
160
|
+
)
|
161
|
+
if args.angular_sampling >= 180:
|
162
|
+
rotations = np.eye(ndim).reshape(1, ndim, ndim)
|
163
|
+
return rotations
|
164
|
+
|
165
|
+
if args.axis_sampling is None:
|
166
|
+
args.axis_sampling = args.cone_sampling
|
167
|
+
|
168
|
+
rotations = get_rotations_around_vector(
|
169
|
+
cone_angle=args.cone_angle,
|
170
|
+
cone_sampling=args.cone_sampling,
|
171
|
+
axis_angle=args.axis_angle,
|
172
|
+
axis_sampling=args.axis_sampling,
|
173
|
+
n_symmetry=args.axis_symmetry,
|
174
|
+
)
|
175
|
+
return rotations
|
176
|
+
|
177
|
+
|
178
|
+
# TODO: Think about whether wedge mask should also be added to target
|
179
|
+
# For now leave it at the cost of incorrect upper bound on the scores
|
180
|
+
def setup_filter(args, template: Density, target: Density) -> Tuple[Compose, Compose]:
|
181
|
+
from tme.preprocessing import LinearWhiteningFilter, BandPassFilter
|
182
|
+
from tme.preprocessing.tilt_series import (
|
183
|
+
Wedge,
|
184
|
+
WedgeReconstructed,
|
185
|
+
ReconstructFromTilt,
|
186
|
+
)
|
187
|
+
|
188
|
+
template_filter, target_filter = [], []
|
189
|
+
if args.tilt_angles is not None:
|
190
|
+
try:
|
191
|
+
wedge = Wedge.from_file(args.tilt_angles)
|
192
|
+
wedge.weight_type = args.tilt_weighting
|
193
|
+
if args.tilt_weighting in ("angle", None) and args.ctf_file is None:
|
194
|
+
wedge = WedgeReconstructed(
|
195
|
+
angles=wedge.angles, weight_wedge=args.tilt_weighting == "angle"
|
196
|
+
)
|
197
|
+
except FileNotFoundError:
|
198
|
+
tilt_step, create_continuous_wedge = None, True
|
199
|
+
tilt_start, tilt_stop = args.tilt_angles.split(",")
|
200
|
+
if ":" in tilt_stop:
|
201
|
+
create_continuous_wedge = False
|
202
|
+
tilt_stop, tilt_step = tilt_stop.split(":")
|
203
|
+
tilt_start, tilt_stop = float(tilt_start), float(tilt_stop)
|
204
|
+
tilt_angles = (tilt_start, tilt_stop)
|
205
|
+
if tilt_step is not None:
|
206
|
+
tilt_step = float(tilt_step)
|
207
|
+
tilt_angles = np.arange(
|
208
|
+
-tilt_start, tilt_stop + tilt_step, tilt_step
|
209
|
+
).tolist()
|
210
|
+
|
211
|
+
if args.tilt_weighting is not None and tilt_step is None:
|
212
|
+
raise ValueError(
|
213
|
+
"Tilt weighting is not supported for continuous wedges."
|
214
|
+
)
|
215
|
+
if args.tilt_weighting not in ("angle", None):
|
216
|
+
raise ValueError(
|
217
|
+
"Tilt weighting schemes other than 'angle' or 'None' require "
|
218
|
+
"a specification of electron doses via --tilt_angles."
|
219
|
+
)
|
220
|
+
|
221
|
+
wedge = Wedge(
|
222
|
+
angles=tilt_angles,
|
223
|
+
opening_axis=args.wedge_axes[0],
|
224
|
+
tilt_axis=args.wedge_axes[1],
|
225
|
+
shape=None,
|
226
|
+
weight_type=None,
|
227
|
+
weights=np.ones_like(tilt_angles),
|
228
|
+
)
|
229
|
+
if args.tilt_weighting in ("angle", None) and args.ctf_file is None:
|
230
|
+
wedge = WedgeReconstructed(
|
231
|
+
angles=tilt_angles,
|
232
|
+
weight_wedge=args.tilt_weighting == "angle",
|
233
|
+
create_continuous_wedge=create_continuous_wedge,
|
234
|
+
)
|
235
|
+
|
236
|
+
wedge.opening_axis = args.wedge_axes[0]
|
237
|
+
wedge.tilt_axis = args.wedge_axes[1]
|
238
|
+
wedge.sampling_rate = template.sampling_rate
|
239
|
+
template_filter.append(wedge)
|
240
|
+
if not isinstance(wedge, WedgeReconstructed):
|
241
|
+
template_filter.append(
|
242
|
+
ReconstructFromTilt(
|
243
|
+
reconstruction_filter=args.reconstruction_filter,
|
244
|
+
interpolation_order=args.reconstruction_interpolation_order,
|
245
|
+
)
|
246
|
+
)
|
247
|
+
|
248
|
+
if args.ctf_file is not None or args.defocus is not None:
|
249
|
+
from tme.preprocessing.tilt_series import CTF
|
250
|
+
|
251
|
+
needs_reconstruction = True
|
252
|
+
if args.ctf_file is not None:
|
253
|
+
ctf = CTF.from_file(args.ctf_file)
|
254
|
+
n_tilts_ctfs, n_tils_angles = len(ctf.defocus_x), len(wedge.angles)
|
255
|
+
if n_tilts_ctfs != n_tils_angles:
|
256
|
+
raise ValueError(
|
257
|
+
f"CTF file contains {n_tilts_ctfs} micrographs, but match_template "
|
258
|
+
f"recieved {n_tils_angles} tilt angles. Expected one angle "
|
259
|
+
"per micrograph."
|
260
|
+
)
|
261
|
+
ctf.angles = wedge.angles
|
262
|
+
ctf.opening_axis, ctf.tilt_axis = args.wedge_axes
|
263
|
+
else:
|
264
|
+
needs_reconstruction = False
|
265
|
+
ctf = CTF(
|
266
|
+
defocus_x=[args.defocus],
|
267
|
+
phase_shift=[args.phase_shift],
|
268
|
+
defocus_y=None,
|
269
|
+
angles=[0],
|
270
|
+
shape=None,
|
271
|
+
return_real_fourier=True,
|
272
|
+
)
|
273
|
+
ctf.sampling_rate = template.sampling_rate
|
274
|
+
ctf.flip_phase = args.no_flip_phase
|
275
|
+
ctf.amplitude_contrast = args.amplitude_contrast
|
276
|
+
ctf.spherical_aberration = args.spherical_aberration
|
277
|
+
ctf.acceleration_voltage = args.acceleration_voltage * 1e3
|
278
|
+
ctf.correct_defocus_gradient = args.correct_defocus_gradient
|
279
|
+
|
280
|
+
if not needs_reconstruction:
|
281
|
+
template_filter.append(ctf)
|
282
|
+
elif isinstance(template_filter[-1], ReconstructFromTilt):
|
283
|
+
template_filter.insert(-1, ctf)
|
284
|
+
else:
|
285
|
+
template_filter.insert(0, ctf)
|
286
|
+
template_filter.insert(
|
287
|
+
1,
|
288
|
+
ReconstructFromTilt(
|
289
|
+
reconstruction_filter=args.reconstruction_filter,
|
290
|
+
interpolation_order=args.reconstruction_interpolation_order,
|
291
|
+
),
|
292
|
+
)
|
293
|
+
|
294
|
+
if args.lowpass or args.highpass is not None:
|
295
|
+
lowpass, highpass = args.lowpass, args.highpass
|
296
|
+
if args.pass_format == "voxel":
|
297
|
+
if lowpass is not None:
|
298
|
+
lowpass = np.max(np.multiply(lowpass, template.sampling_rate))
|
299
|
+
if highpass is not None:
|
300
|
+
highpass = np.max(np.multiply(highpass, template.sampling_rate))
|
301
|
+
elif args.pass_format == "frequency":
|
302
|
+
if lowpass is not None:
|
303
|
+
lowpass = np.max(np.divide(template.sampling_rate, lowpass))
|
304
|
+
if highpass is not None:
|
305
|
+
highpass = np.max(np.divide(template.sampling_rate, highpass))
|
306
|
+
|
307
|
+
try:
|
308
|
+
if args.lowpass >= args.highpass:
|
309
|
+
warnings.warn("--lowpass should be smaller than --highpass.")
|
310
|
+
except Exception:
|
311
|
+
pass
|
312
|
+
|
313
|
+
bandpass = BandPassFilter(
|
314
|
+
use_gaussian=args.no_pass_smooth,
|
315
|
+
lowpass=lowpass,
|
316
|
+
highpass=highpass,
|
317
|
+
sampling_rate=template.sampling_rate,
|
318
|
+
)
|
319
|
+
template_filter.append(bandpass)
|
320
|
+
|
321
|
+
if not args.no_filter_target:
|
322
|
+
target_filter.append(bandpass)
|
323
|
+
|
324
|
+
if args.whiten_spectrum:
|
325
|
+
whitening_filter = LinearWhiteningFilter()
|
326
|
+
template_filter.append(whitening_filter)
|
327
|
+
target_filter.append(whitening_filter)
|
328
|
+
|
329
|
+
needs_reconstruction = any(
|
330
|
+
[isinstance(t, ReconstructFromTilt) for t in template_filter]
|
331
|
+
)
|
332
|
+
if needs_reconstruction and args.reconstruction_filter is None:
|
333
|
+
warnings.warn(
|
334
|
+
"Consider using a --reconstruction_filter such as 'ramp' to avoid artifacts."
|
335
|
+
)
|
336
|
+
|
337
|
+
template_filter = Compose(template_filter) if len(template_filter) else None
|
338
|
+
target_filter = Compose(target_filter) if len(target_filter) else None
|
339
|
+
|
340
|
+
return template_filter, target_filter
|
341
|
+
|
342
|
+
|
153
343
|
def parse_args():
|
154
|
-
parser = argparse.ArgumentParser(
|
344
|
+
parser = argparse.ArgumentParser(
|
345
|
+
description="Perform template matching.",
|
346
|
+
formatter_class=argparse.ArgumentDefaultsHelpFormatter,
|
347
|
+
)
|
155
348
|
|
156
349
|
io_group = parser.add_argument_group("Input / Output")
|
157
350
|
io_group.add_argument(
|
@@ -221,22 +414,65 @@ def parse_args():
|
|
221
414
|
choices=list(MATCHING_EXHAUSTIVE_REGISTER.keys()),
|
222
415
|
help="Template matching scoring function.",
|
223
416
|
)
|
224
|
-
|
225
|
-
|
226
|
-
|
227
|
-
|
228
|
-
|
229
|
-
help="Perform peak calling instead of score aggregation.",
|
230
|
-
)
|
231
|
-
scoring_group.add_argument(
|
417
|
+
|
418
|
+
angular_group = parser.add_argument_group("Angular Sampling")
|
419
|
+
angular_exclusive = angular_group.add_mutually_exclusive_group(required=True)
|
420
|
+
|
421
|
+
angular_exclusive.add_argument(
|
232
422
|
"-a",
|
233
423
|
dest="angular_sampling",
|
234
424
|
type=check_positive,
|
235
|
-
default=
|
236
|
-
help="Angular sampling rate
|
425
|
+
default=None,
|
426
|
+
help="Angular sampling rate using optimized rotational sets."
|
237
427
|
"A lower number yields more rotations. Values >= 180 sample only the identity.",
|
238
428
|
)
|
239
|
-
|
429
|
+
angular_exclusive.add_argument(
|
430
|
+
"--cone_angle",
|
431
|
+
dest="cone_angle",
|
432
|
+
type=check_positive,
|
433
|
+
default=None,
|
434
|
+
help="Half-angle of the cone to be sampled in degrees. Allows to sample a "
|
435
|
+
"narrow interval around a known orientation, e.g. for surface oversampling.",
|
436
|
+
)
|
437
|
+
angular_group.add_argument(
|
438
|
+
"--cone_sampling",
|
439
|
+
dest="cone_sampling",
|
440
|
+
type=check_positive,
|
441
|
+
default=None,
|
442
|
+
help="Sampling rate of the cone in degrees.",
|
443
|
+
)
|
444
|
+
angular_group.add_argument(
|
445
|
+
"--axis_angle",
|
446
|
+
dest="axis_angle",
|
447
|
+
type=check_positive,
|
448
|
+
default=360.0,
|
449
|
+
required=False,
|
450
|
+
help="Sampling angle along the z-axis of the cone.",
|
451
|
+
)
|
452
|
+
angular_group.add_argument(
|
453
|
+
"--axis_sampling",
|
454
|
+
dest="axis_sampling",
|
455
|
+
type=check_positive,
|
456
|
+
default=None,
|
457
|
+
required=False,
|
458
|
+
help="Sampling rate along the z-axis of the cone. Defaults to --cone_sampling.",
|
459
|
+
)
|
460
|
+
angular_group.add_argument(
|
461
|
+
"--axis_symmetry",
|
462
|
+
dest="axis_symmetry",
|
463
|
+
type=check_positive,
|
464
|
+
default=1,
|
465
|
+
required=False,
|
466
|
+
help="N-fold symmetry around z-axis of the cone.",
|
467
|
+
)
|
468
|
+
angular_group.add_argument(
|
469
|
+
"--no_use_optimized_set",
|
470
|
+
dest="no_use_optimized_set",
|
471
|
+
action="store_true",
|
472
|
+
default=False,
|
473
|
+
required=False,
|
474
|
+
help="Whether to use random uniform instead of optimized rotation sets.",
|
475
|
+
)
|
240
476
|
|
241
477
|
computation_group = parser.add_argument_group("Computation")
|
242
478
|
computation_group.add_argument(
|
@@ -279,23 +515,7 @@ def parse_args():
|
|
279
515
|
required=False,
|
280
516
|
type=float,
|
281
517
|
default=0.85,
|
282
|
-
help="Fraction of available memory
|
283
|
-
"ignored if --ram is set",
|
284
|
-
)
|
285
|
-
computation_group.add_argument(
|
286
|
-
"--use_mixed_precision",
|
287
|
-
dest="use_mixed_precision",
|
288
|
-
action="store_true",
|
289
|
-
default=False,
|
290
|
-
help="Use float16 for real values operations where possible.",
|
291
|
-
)
|
292
|
-
computation_group.add_argument(
|
293
|
-
"--use_memmap",
|
294
|
-
dest="use_memmap",
|
295
|
-
action="store_true",
|
296
|
-
default=False,
|
297
|
-
help="Use memmaps to offload large data objects to disk. "
|
298
|
-
"Particularly useful for large inputs in combination with --use_gpu.",
|
518
|
+
help="Fraction of available memory to be used. Ignored if --ram is set.",
|
299
519
|
)
|
300
520
|
computation_group.add_argument(
|
301
521
|
"--temp_directory",
|
@@ -303,64 +523,176 @@ def parse_args():
|
|
303
523
|
default=None,
|
304
524
|
help="Directory for temporary objects. Faster I/O improves runtime.",
|
305
525
|
)
|
306
|
-
|
526
|
+
computation_group.add_argument(
|
527
|
+
"--backend",
|
528
|
+
dest="backend",
|
529
|
+
default=None,
|
530
|
+
choices=be.available_backends(),
|
531
|
+
help="[Expert] Overwrite default computation backend.",
|
532
|
+
)
|
307
533
|
filter_group = parser.add_argument_group("Filters")
|
308
534
|
filter_group.add_argument(
|
309
|
-
"--
|
310
|
-
dest="
|
535
|
+
"--lowpass",
|
536
|
+
dest="lowpass",
|
311
537
|
type=float,
|
312
538
|
required=False,
|
313
|
-
help="
|
539
|
+
help="Resolution to lowpass filter template and target to in the same unit "
|
540
|
+
"as the sampling rate of template and target (typically Ångstrom).",
|
541
|
+
)
|
542
|
+
filter_group.add_argument(
|
543
|
+
"--highpass",
|
544
|
+
dest="highpass",
|
545
|
+
type=float,
|
546
|
+
required=False,
|
547
|
+
help="Resolution to highpass filter template and target to in the same unit "
|
548
|
+
"as the sampling rate of template and target (typically Ångstrom).",
|
549
|
+
)
|
550
|
+
filter_group.add_argument(
|
551
|
+
"--no_pass_smooth",
|
552
|
+
dest="no_pass_smooth",
|
553
|
+
action="store_false",
|
554
|
+
default=True,
|
555
|
+
help="Whether a hard edge filter should be used for --lowpass and --highpass.",
|
314
556
|
)
|
315
557
|
filter_group.add_argument(
|
316
|
-
"--
|
317
|
-
dest="
|
558
|
+
"--pass_format",
|
559
|
+
dest="pass_format",
|
318
560
|
type=str,
|
319
561
|
required=False,
|
320
|
-
|
321
|
-
"
|
562
|
+
default="sampling_rate",
|
563
|
+
choices=["sampling_rate", "voxel", "frequency"],
|
564
|
+
help="How values passed to --lowpass and --highpass should be interpreted. ",
|
322
565
|
)
|
323
566
|
filter_group.add_argument(
|
324
|
-
"--
|
325
|
-
dest="
|
326
|
-
|
567
|
+
"--whiten_spectrum",
|
568
|
+
dest="whiten_spectrum",
|
569
|
+
action="store_true",
|
570
|
+
default=None,
|
571
|
+
help="Apply spectral whitening to template and target based on target spectrum.",
|
572
|
+
)
|
573
|
+
filter_group.add_argument(
|
574
|
+
"--wedge_axes",
|
575
|
+
dest="wedge_axes",
|
576
|
+
type=str,
|
327
577
|
required=False,
|
328
578
|
default=None,
|
329
|
-
help="
|
579
|
+
help="Indices of wedge opening and tilt axis, e.g. 0,2 for a wedge that is open "
|
580
|
+
"in z-direction and tilted over the x axis.",
|
330
581
|
)
|
331
582
|
filter_group.add_argument(
|
332
|
-
"--
|
333
|
-
dest="
|
583
|
+
"--tilt_angles",
|
584
|
+
dest="tilt_angles",
|
334
585
|
type=str,
|
335
586
|
required=False,
|
336
|
-
|
337
|
-
" to
|
587
|
+
default=None,
|
588
|
+
help="Path to a tab-separated file containing the column angles and optionally "
|
589
|
+
" weights, or comma separated start and stop stage tilt angle, e.g. 50,45, which "
|
590
|
+
" yields a continuous wedge mask. Alternatively, a tilt step size can be "
|
591
|
+
"specified like 50,45:5.0 to sample 5.0 degree tilt angle steps.",
|
338
592
|
)
|
339
593
|
filter_group.add_argument(
|
340
|
-
"--
|
341
|
-
dest="
|
342
|
-
type=
|
594
|
+
"--tilt_weighting",
|
595
|
+
dest="tilt_weighting",
|
596
|
+
type=str,
|
343
597
|
required=False,
|
598
|
+
choices=["angle", "relion", "grigorieff"],
|
344
599
|
default=None,
|
345
|
-
help="
|
346
|
-
"
|
600
|
+
help="Weighting scheme used to reweight individual tilts. Available options: "
|
601
|
+
"angle (cosine based weighting), "
|
602
|
+
"relion (relion formalism for wedge weighting) requires,"
|
603
|
+
"grigorieff (exposure filter as defined in Grant and Grigorieff 2015)."
|
604
|
+
"relion and grigorieff require electron doses in --tilt_angles weights column.",
|
347
605
|
)
|
348
606
|
filter_group.add_argument(
|
349
|
-
"--
|
350
|
-
dest="
|
607
|
+
"--reconstruction_filter",
|
608
|
+
dest="reconstruction_filter",
|
351
609
|
type=str,
|
352
610
|
required=False,
|
353
|
-
|
354
|
-
|
355
|
-
"
|
611
|
+
choices=["ram-lak", "ramp", "ramp-cont", "shepp-logan", "cosine", "hamming"],
|
612
|
+
default=None,
|
613
|
+
help="Filter applied when reconstructing (N+1)-D from N-D filters.",
|
356
614
|
)
|
357
615
|
filter_group.add_argument(
|
358
|
-
"--
|
359
|
-
dest="
|
616
|
+
"--reconstruction_interpolation_order",
|
617
|
+
dest="reconstruction_interpolation_order",
|
618
|
+
type=int,
|
619
|
+
default=1,
|
620
|
+
required=False,
|
621
|
+
help="Analogous to --interpolation_order but for reconstruction.",
|
622
|
+
)
|
623
|
+
filter_group.add_argument(
|
624
|
+
"--no_filter_target",
|
625
|
+
dest="no_filter_target",
|
626
|
+
action="store_true",
|
627
|
+
default=False,
|
628
|
+
help="Whether to not apply potential filters to the target.",
|
629
|
+
)
|
630
|
+
|
631
|
+
ctf_group = parser.add_argument_group("Contrast Transfer Function")
|
632
|
+
ctf_group.add_argument(
|
633
|
+
"--ctf_file",
|
634
|
+
dest="ctf_file",
|
635
|
+
type=str,
|
636
|
+
required=False,
|
637
|
+
default=None,
|
638
|
+
help="Path to a file with CTF parameters from CTFFIND4. Each line will be "
|
639
|
+
"interpreted as tilt obtained at the angle specified in --tilt_angles. ",
|
640
|
+
)
|
641
|
+
ctf_group.add_argument(
|
642
|
+
"--defocus",
|
643
|
+
dest="defocus",
|
360
644
|
type=float,
|
361
645
|
required=False,
|
362
646
|
default=None,
|
363
|
-
help="
|
647
|
+
help="Defocus in units of sampling rate (typically Ångstrom). "
|
648
|
+
"Superseded by --ctf_file.",
|
649
|
+
)
|
650
|
+
ctf_group.add_argument(
|
651
|
+
"--phase_shift",
|
652
|
+
dest="phase_shift",
|
653
|
+
type=float,
|
654
|
+
required=False,
|
655
|
+
default=0,
|
656
|
+
help="Phase shift in degrees. Superseded by --ctf_file.",
|
657
|
+
)
|
658
|
+
ctf_group.add_argument(
|
659
|
+
"--acceleration_voltage",
|
660
|
+
dest="acceleration_voltage",
|
661
|
+
type=float,
|
662
|
+
required=False,
|
663
|
+
default=300,
|
664
|
+
help="Acceleration voltage in kV.",
|
665
|
+
)
|
666
|
+
ctf_group.add_argument(
|
667
|
+
"--spherical_aberration",
|
668
|
+
dest="spherical_aberration",
|
669
|
+
type=float,
|
670
|
+
required=False,
|
671
|
+
default=2.7e7,
|
672
|
+
help="Spherical aberration in units of sampling rate (typically Ångstrom).",
|
673
|
+
)
|
674
|
+
ctf_group.add_argument(
|
675
|
+
"--amplitude_contrast",
|
676
|
+
dest="amplitude_contrast",
|
677
|
+
type=float,
|
678
|
+
required=False,
|
679
|
+
default=0.07,
|
680
|
+
help="Amplitude contrast.",
|
681
|
+
)
|
682
|
+
ctf_group.add_argument(
|
683
|
+
"--no_flip_phase",
|
684
|
+
dest="no_flip_phase",
|
685
|
+
action="store_false",
|
686
|
+
required=False,
|
687
|
+
help="Do not perform phase-flipping CTF correction.",
|
688
|
+
)
|
689
|
+
ctf_group.add_argument(
|
690
|
+
"--correct_defocus_gradient",
|
691
|
+
dest="correct_defocus_gradient",
|
692
|
+
action="store_true",
|
693
|
+
required=False,
|
694
|
+
help="[Experimental] Whether to compute a more accurate 3D CTF incorporating "
|
695
|
+
"defocus gradients.",
|
364
696
|
)
|
365
697
|
|
366
698
|
performance_group = parser.add_argument_group("Performance")
|
@@ -403,14 +735,37 @@ def parse_args():
|
|
403
735
|
"for numerical stability. When working with very large targets, e.g. tomograms, "
|
404
736
|
"it is safe to use this flag and benefit from the performance gain.",
|
405
737
|
)
|
738
|
+
performance_group.add_argument(
|
739
|
+
"--no_filter_padding",
|
740
|
+
dest="no_filter_padding",
|
741
|
+
action="store_true",
|
742
|
+
default=False,
|
743
|
+
help="Omits padding of optional template filters. Particularly effective when "
|
744
|
+
"the target is much larger than the template. However, for fast osciliating "
|
745
|
+
"filters setting this flag can introduce aliasing effects.",
|
746
|
+
)
|
406
747
|
performance_group.add_argument(
|
407
748
|
"--interpolation_order",
|
408
749
|
dest="interpolation_order",
|
409
750
|
required=False,
|
410
751
|
type=int,
|
411
752
|
default=3,
|
412
|
-
help="Spline interpolation used for
|
413
|
-
|
753
|
+
help="Spline interpolation used for rotations.",
|
754
|
+
)
|
755
|
+
performance_group.add_argument(
|
756
|
+
"--use_mixed_precision",
|
757
|
+
dest="use_mixed_precision",
|
758
|
+
action="store_true",
|
759
|
+
default=False,
|
760
|
+
help="Use float16 for real values operations where possible.",
|
761
|
+
)
|
762
|
+
performance_group.add_argument(
|
763
|
+
"--use_memmap",
|
764
|
+
dest="use_memmap",
|
765
|
+
action="store_true",
|
766
|
+
default=False,
|
767
|
+
help="Use memmaps to offload large data objects to disk. "
|
768
|
+
"Particularly useful for large inputs in combination with --use_gpu.",
|
414
769
|
)
|
415
770
|
|
416
771
|
analyzer_group = parser.add_argument_group("Analyzer")
|
@@ -422,8 +777,22 @@ def parse_args():
|
|
422
777
|
default=0,
|
423
778
|
help="Minimum template matching scores to consider for analysis.",
|
424
779
|
)
|
425
|
-
|
780
|
+
analyzer_group.add_argument(
|
781
|
+
"-p",
|
782
|
+
dest="peak_calling",
|
783
|
+
action="store_true",
|
784
|
+
default=False,
|
785
|
+
help="Perform peak calling instead of score aggregation.",
|
786
|
+
)
|
787
|
+
analyzer_group.add_argument(
|
788
|
+
"--number_of_peaks",
|
789
|
+
dest="number_of_peaks",
|
790
|
+
action="store_true",
|
791
|
+
default=1000,
|
792
|
+
help="Number of peaks to call, 1000 by default.",
|
793
|
+
)
|
426
794
|
args = parser.parse_args()
|
795
|
+
args.version = __version__
|
427
796
|
|
428
797
|
if args.interpolation_order < 0:
|
429
798
|
args.interpolation_order = None
|
@@ -460,6 +829,21 @@ def parse_args():
|
|
460
829
|
int(x) for x in os.environ["CUDA_VISIBLE_DEVICES"].split(",")
|
461
830
|
]
|
462
831
|
|
832
|
+
if args.tilt_angles is not None:
|
833
|
+
if args.wedge_axes is None:
|
834
|
+
raise ValueError("Need to specify --wedge_axes when --tilt_angles is set.")
|
835
|
+
if not exists(args.tilt_angles):
|
836
|
+
try:
|
837
|
+
float(args.tilt_angles.split(",")[0])
|
838
|
+
except ValueError:
|
839
|
+
raise ValueError(f"{args.tilt_angles} is not a file nor a range.")
|
840
|
+
|
841
|
+
if args.ctf_file is not None and args.tilt_angles is None:
|
842
|
+
raise ValueError("Need to specify --tilt_angles when --ctf_file is set.")
|
843
|
+
|
844
|
+
if args.wedge_axes is not None:
|
845
|
+
args.wedge_axes = tuple(int(i) for i in args.wedge_axes.split(","))
|
846
|
+
|
463
847
|
return args
|
464
848
|
|
465
849
|
|
@@ -477,12 +861,13 @@ def main():
|
|
477
861
|
sampling_rate=target.sampling_rate,
|
478
862
|
)
|
479
863
|
|
480
|
-
if
|
481
|
-
|
482
|
-
|
483
|
-
|
484
|
-
|
485
|
-
|
864
|
+
if target.sampling_rate.size == template.sampling_rate.size:
|
865
|
+
if not np.allclose(target.sampling_rate, template.sampling_rate):
|
866
|
+
print(
|
867
|
+
f"Resampling template to {target.sampling_rate}. "
|
868
|
+
"Consider providing a template with the same sampling rate as the target."
|
869
|
+
)
|
870
|
+
template = template.resample(target.sampling_rate, order=3)
|
486
871
|
|
487
872
|
template_mask = load_and_validate_mask(
|
488
873
|
mask_target=template, mask_path=args.template_mask
|
@@ -536,51 +921,6 @@ def main():
|
|
536
921
|
},
|
537
922
|
)
|
538
923
|
|
539
|
-
template_filter = {}
|
540
|
-
if args.gaussian_sigma is not None:
|
541
|
-
template.data = Preprocessor().gaussian_filter(
|
542
|
-
sigma=args.gaussian_sigma, template=template.data
|
543
|
-
)
|
544
|
-
|
545
|
-
if args.bandpass_band is not None:
|
546
|
-
bandpass_start, bandpass_stop = [
|
547
|
-
float(x) for x in args.bandpass_band.split(",")
|
548
|
-
]
|
549
|
-
if args.bandpass_smooth is None:
|
550
|
-
args.bandpass_smooth = 0
|
551
|
-
|
552
|
-
template_filter["bandpass_mask"] = {
|
553
|
-
"minimum_frequency": bandpass_start,
|
554
|
-
"maximum_frequency": bandpass_stop,
|
555
|
-
"gaussian_sigma": args.bandpass_smooth,
|
556
|
-
}
|
557
|
-
|
558
|
-
if args.tilt_range is not None:
|
559
|
-
args.wedge_smooth if args.wedge_smooth is not None else 0
|
560
|
-
tilt_start, tilt_stop = [float(x) for x in args.tilt_range.split(",")]
|
561
|
-
opening_axis, tilt_axis = [int(x) for x in args.wedge_axes.split(",")]
|
562
|
-
|
563
|
-
if args.tilt_step is not None:
|
564
|
-
template_filter["step_wedge_mask"] = {
|
565
|
-
"start_tilt": tilt_start,
|
566
|
-
"stop_tilt": tilt_stop,
|
567
|
-
"tilt_step": args.tilt_step,
|
568
|
-
"sigma": args.wedge_smooth,
|
569
|
-
"opening_axis": opening_axis,
|
570
|
-
"tilt_axis": tilt_axis,
|
571
|
-
"omit_negative_frequencies": True,
|
572
|
-
}
|
573
|
-
else:
|
574
|
-
template_filter["continuous_wedge_mask"] = {
|
575
|
-
"start_tilt": tilt_start,
|
576
|
-
"stop_tilt": tilt_stop,
|
577
|
-
"tilt_axis": tilt_axis,
|
578
|
-
"opening_axis": opening_axis,
|
579
|
-
"infinite_plane": True,
|
580
|
-
"sigma": args.wedge_smooth,
|
581
|
-
"omit_negative_frequencies": True,
|
582
|
-
}
|
583
|
-
|
584
924
|
if template_mask is None:
|
585
925
|
template_mask = template.empty
|
586
926
|
if not args.no_centering:
|
@@ -619,72 +959,97 @@ def main():
|
|
619
959
|
template.data, noise_proportion=1.0, normalize_power=True
|
620
960
|
)
|
621
961
|
|
622
|
-
|
962
|
+
# Determine suitable backend for the selected operation
|
963
|
+
available_backends = be.available_backends()
|
964
|
+
if args.backend is not None:
|
965
|
+
req_backend = args.backend
|
966
|
+
if req_backend not in available_backends:
|
967
|
+
raise ValueError("Requested backend is not available.")
|
968
|
+
available_backends = [req_backend]
|
969
|
+
|
970
|
+
be_selection = ("numpyfftw", "pytorch", "jax", "mlx")
|
623
971
|
if args.use_gpu:
|
624
972
|
args.cores = len(args.gpu_indices)
|
625
|
-
|
626
|
-
|
973
|
+
be_selection = ("pytorch", "cupy", "jax")
|
974
|
+
if args.use_mixed_precision:
|
975
|
+
be_selection = tuple(x for x in be_selection if x in ("cupy", "numpyfftw"))
|
627
976
|
|
628
|
-
|
629
|
-
|
630
|
-
|
631
|
-
|
632
|
-
|
633
|
-
|
634
|
-
|
635
|
-
preferred_backend = "pytorch"
|
636
|
-
if not has_torch:
|
637
|
-
preferred_backend = "cupy"
|
638
|
-
backend.change_backend(backend_name=preferred_backend, device="cuda")
|
639
|
-
else:
|
640
|
-
preferred_backend = "cupy"
|
641
|
-
if not has_cupy:
|
642
|
-
preferred_backend = "pytorch"
|
643
|
-
backend.change_backend(backend_name=preferred_backend, device="cuda")
|
644
|
-
if args.use_mixed_precision and preferred_backend == "pytorch":
|
977
|
+
available_backends = [x for x in available_backends if x in be_selection]
|
978
|
+
if args.peak_calling:
|
979
|
+
if "jax" in available_backends:
|
980
|
+
available_backends.remove("jax")
|
981
|
+
if args.use_gpu and "pytorch" in available_backends:
|
982
|
+
available_backends = ("pytorch",)
|
983
|
+
if args.interpolation_order == 3:
|
645
984
|
raise NotImplementedError(
|
646
|
-
"
|
647
|
-
" Consider installing CuPy to enable this feature."
|
648
|
-
)
|
649
|
-
elif args.use_mixed_precision:
|
650
|
-
backend.change_backend(
|
651
|
-
backend_name="cupy",
|
652
|
-
default_dtype=backend._array_backend.float16,
|
653
|
-
complex_dtype=backend._array_backend.complex64,
|
654
|
-
default_dtype_int=backend._array_backend.int16,
|
985
|
+
"Pytorch does not support --interpolation_order 3, 1 is supported."
|
655
986
|
)
|
656
|
-
|
657
|
-
|
658
|
-
|
987
|
+
# dim_match = len(template.shape) == len(target.shape) <= 3
|
988
|
+
# if dim_match and args.use_gpu and "jax" in available_backends:
|
989
|
+
# args.interpolation_order = 1
|
990
|
+
# available_backends = ["jax"]
|
659
991
|
|
992
|
+
backend_preference = ("numpyfftw", "pytorch", "jax", "mlx")
|
993
|
+
if args.use_gpu:
|
994
|
+
backend_preference = ("cupy", "pytorch", "jax")
|
995
|
+
for pref in backend_preference:
|
996
|
+
if pref not in available_backends:
|
997
|
+
continue
|
998
|
+
be.change_backend(pref)
|
999
|
+
if pref == "pytorch":
|
1000
|
+
be.change_backend(pref, device="cuda" if args.use_gpu else "cpu")
|
1001
|
+
|
1002
|
+
if args.use_mixed_precision:
|
1003
|
+
be.change_backend(
|
1004
|
+
backend_name=pref,
|
1005
|
+
default_dtype=be._array_backend.float16,
|
1006
|
+
complex_dtype=be._array_backend.complex64,
|
1007
|
+
default_dtype_int=be._array_backend.int16,
|
1008
|
+
)
|
1009
|
+
break
|
1010
|
+
|
1011
|
+
available_memory = be.get_available_memory() * be.device_count()
|
660
1012
|
if args.memory is None:
|
661
1013
|
args.memory = int(args.memory_scaling * available_memory)
|
662
1014
|
|
663
|
-
target_padding = np.zeros_like(template.shape)
|
664
|
-
if args.pad_target_edges:
|
665
|
-
target_padding = template.shape
|
666
|
-
|
667
|
-
template_box = template.shape
|
668
|
-
if not args.pad_fourier:
|
669
|
-
template_box = np.ones(len(template_box), dtype=int)
|
670
|
-
|
671
1015
|
callback_class = MaxScoreOverRotations
|
672
1016
|
if args.peak_calling:
|
673
1017
|
callback_class = PeakCallerMaximumFilter
|
674
1018
|
|
1019
|
+
matching_data = MatchingData(
|
1020
|
+
target=target,
|
1021
|
+
template=template.data,
|
1022
|
+
target_mask=target_mask,
|
1023
|
+
template_mask=template_mask,
|
1024
|
+
invert_target=args.invert_target_contrast,
|
1025
|
+
rotations=parse_rotation_logic(args=args, ndim=template.data.ndim),
|
1026
|
+
)
|
1027
|
+
|
1028
|
+
matching_data.template_filter, matching_data.target_filter = setup_filter(
|
1029
|
+
args, template, target
|
1030
|
+
)
|
1031
|
+
|
1032
|
+
template_box = matching_data._output_template_shape
|
1033
|
+
if not args.pad_fourier:
|
1034
|
+
template_box = tuple(0 for _ in range(len(template_box)))
|
1035
|
+
|
1036
|
+
target_padding = tuple(0 for _ in range(len(template_box)))
|
1037
|
+
if args.pad_target_edges:
|
1038
|
+
target_padding = matching_data._output_template_shape
|
1039
|
+
|
675
1040
|
splits, schedule = compute_parallelization_schedule(
|
676
1041
|
shape1=target.shape,
|
677
|
-
shape2=template_box,
|
678
|
-
shape1_padding=target_padding,
|
1042
|
+
shape2=tuple(int(x) for x in template_box),
|
1043
|
+
shape1_padding=tuple(int(x) for x in target_padding),
|
679
1044
|
max_cores=args.cores,
|
680
1045
|
max_ram=args.memory,
|
681
1046
|
split_only_outer=args.use_gpu,
|
682
1047
|
matching_method=args.score,
|
683
1048
|
analyzer_method=callback_class.__name__,
|
684
|
-
backend=
|
685
|
-
float_nbytes=
|
686
|
-
complex_nbytes=
|
687
|
-
integer_nbytes=
|
1049
|
+
backend=be._backend_name,
|
1050
|
+
float_nbytes=be.datatype_bytes(be._float_dtype),
|
1051
|
+
complex_nbytes=be.datatype_bytes(be._complex_dtype),
|
1052
|
+
integer_nbytes=be.datatype_bytes(be._int_dtype),
|
688
1053
|
)
|
689
1054
|
|
690
1055
|
if splits is None:
|
@@ -694,63 +1059,85 @@ def main():
|
|
694
1059
|
)
|
695
1060
|
exit(-1)
|
696
1061
|
|
697
|
-
analyzer_args = {
|
698
|
-
"score_threshold": args.score_threshold,
|
699
|
-
"number_of_peaks": 1000,
|
700
|
-
"convolution_mode": "valid",
|
701
|
-
"use_memmap": args.use_memmap,
|
702
|
-
}
|
703
|
-
|
704
1062
|
matching_setup, matching_score = MATCHING_EXHAUSTIVE_REGISTER[args.score]
|
705
|
-
matching_data = MatchingData(target=target, template=template.data)
|
706
|
-
matching_data.rotations = get_rotation_matrices(
|
707
|
-
angular_sampling=args.angular_sampling, dim=target.data.ndim
|
708
|
-
)
|
709
|
-
if args.angular_sampling >= 180:
|
710
|
-
ndim = target.data.ndim
|
711
|
-
matching_data.rotations = np.eye(ndim).reshape(1, ndim, ndim)
|
712
|
-
|
713
|
-
matching_data.template_filter = template_filter
|
714
|
-
matching_data._invert_target = args.invert_target_contrast
|
715
|
-
if target_mask is not None:
|
716
|
-
matching_data.target_mask = target_mask
|
717
|
-
if template_mask is not None:
|
718
|
-
matching_data.template_mask = template_mask.data
|
719
|
-
|
720
1063
|
n_splits = np.prod(list(splits.values()))
|
721
1064
|
target_split = ", ".join(
|
722
1065
|
[":".join([str(x) for x in axis]) for axis in splits.items()]
|
723
1066
|
)
|
724
1067
|
gpus_used = 0 if args.gpu_indices is None else len(args.gpu_indices)
|
725
1068
|
options = {
|
726
|
-
"
|
727
|
-
"
|
728
|
-
"
|
729
|
-
"
|
730
|
-
"
|
1069
|
+
"Angular Sampling": f"{args.angular_sampling}"
|
1070
|
+
f" [{matching_data.rotations.shape[0]} rotations]",
|
1071
|
+
"Center Template": not args.no_centering,
|
1072
|
+
"Scramble Template": args.scramble_phases,
|
1073
|
+
"Invert Contrast": args.invert_target_contrast,
|
731
1074
|
"Extend Fourier Grid": not args.no_fourier_padding,
|
732
1075
|
"Extend Target Edges": not args.no_edge_padding,
|
733
1076
|
"Interpolation Order": args.interpolation_order,
|
734
|
-
"Score": f"{args.score}",
|
735
1077
|
"Setup Function": f"{get_func_fullname(matching_setup)}",
|
736
1078
|
"Scoring Function": f"{get_func_fullname(matching_score)}",
|
737
|
-
"Angular Sampling": f"{args.angular_sampling}"
|
738
|
-
f" [{matching_data.rotations.shape[0]} rotations]",
|
739
|
-
"Scramble Template": args.scramble_phases,
|
740
|
-
"Target Splits": f"{target_split} [N={n_splits}]",
|
741
1079
|
}
|
742
1080
|
|
743
1081
|
print_block(
|
744
|
-
name="Template Matching
|
1082
|
+
name="Template Matching",
|
745
1083
|
data=options,
|
746
|
-
label_width=max(len(key) for key in options.keys()) +
|
1084
|
+
label_width=max(len(key) for key in options.keys()) + 3,
|
747
1085
|
)
|
748
1086
|
|
749
|
-
|
1087
|
+
compute_options = {
|
1088
|
+
"Backend": be._BACKEND_REGISTRY[be._backend_name],
|
1089
|
+
"Compute Devices": f"CPU [{args.cores}], GPU [{gpus_used}]",
|
1090
|
+
"Use Mixed Precision": args.use_mixed_precision,
|
1091
|
+
"Assigned Memory [MB]": f"{args.memory // 1e6} [out of {available_memory//1e6}]",
|
1092
|
+
"Temporary Directory": args.temp_directory,
|
1093
|
+
"Target Splits": f"{target_split} [N={n_splits}]",
|
1094
|
+
}
|
750
1095
|
print_block(
|
751
|
-
name="
|
752
|
-
data=
|
753
|
-
label_width=max(len(key) for key in options.keys()) +
|
1096
|
+
name="Computation",
|
1097
|
+
data=compute_options,
|
1098
|
+
label_width=max(len(key) for key in options.keys()) + 3,
|
1099
|
+
)
|
1100
|
+
|
1101
|
+
filter_args = {
|
1102
|
+
"Lowpass": args.lowpass,
|
1103
|
+
"Highpass": args.highpass,
|
1104
|
+
"Smooth Pass": args.no_pass_smooth,
|
1105
|
+
"Pass Format": args.pass_format,
|
1106
|
+
"Spectral Whitening": args.whiten_spectrum,
|
1107
|
+
"Wedge Axes": args.wedge_axes,
|
1108
|
+
"Tilt Angles": args.tilt_angles,
|
1109
|
+
"Tilt Weighting": args.tilt_weighting,
|
1110
|
+
"Reconstruction Filter": args.reconstruction_filter,
|
1111
|
+
}
|
1112
|
+
if args.ctf_file is not None or args.defocus is not None:
|
1113
|
+
filter_args["CTF File"] = args.ctf_file
|
1114
|
+
filter_args["Defocus"] = args.defocus
|
1115
|
+
filter_args["Phase Shift"] = args.phase_shift
|
1116
|
+
filter_args["Flip Phase"] = args.no_flip_phase
|
1117
|
+
filter_args["Acceleration Voltage"] = args.acceleration_voltage
|
1118
|
+
filter_args["Spherical Aberration"] = args.spherical_aberration
|
1119
|
+
filter_args["Amplitude Contrast"] = args.amplitude_contrast
|
1120
|
+
filter_args["Correct Defocus"] = args.correct_defocus_gradient
|
1121
|
+
|
1122
|
+
filter_args = {k: v for k, v in filter_args.items() if v is not None}
|
1123
|
+
if len(filter_args):
|
1124
|
+
print_block(
|
1125
|
+
name="Filters",
|
1126
|
+
data=filter_args,
|
1127
|
+
label_width=max(len(key) for key in options.keys()) + 3,
|
1128
|
+
)
|
1129
|
+
|
1130
|
+
analyzer_args = {
|
1131
|
+
"score_threshold": args.score_threshold,
|
1132
|
+
"number_of_peaks": args.number_of_peaks,
|
1133
|
+
"min_distance": max(template.shape) // 2,
|
1134
|
+
"min_boundary_distance": max(template.shape) // 2,
|
1135
|
+
"use_memmap": args.use_memmap,
|
1136
|
+
}
|
1137
|
+
print_block(
|
1138
|
+
name="Analyzer",
|
1139
|
+
data={"Analyzer": callback_class, **analyzer_args},
|
1140
|
+
label_width=max(len(key) for key in options.keys()) + 3,
|
754
1141
|
)
|
755
1142
|
print("\n" + "-" * 80)
|
756
1143
|
|
@@ -771,6 +1158,7 @@ def main():
|
|
771
1158
|
target_splits=splits,
|
772
1159
|
pad_target_edges=args.pad_target_edges,
|
773
1160
|
pad_fourier=args.pad_fourier,
|
1161
|
+
pad_template_filter=not args.no_filter_padding,
|
774
1162
|
interpolation_order=args.interpolation_order,
|
775
1163
|
)
|
776
1164
|
|
@@ -780,19 +1168,18 @@ def main():
|
|
780
1168
|
candidates[0] *= target_mask.data
|
781
1169
|
with warnings.catch_warnings():
|
782
1170
|
warnings.simplefilter("ignore", category=UserWarning)
|
1171
|
+
nbytes = be.datatype_bytes(be._float_dtype)
|
1172
|
+
dtype = np.float32 if nbytes == 4 else np.float16
|
1173
|
+
rot_dim = matching_data.rotations.shape[1]
|
783
1174
|
candidates[3] = {
|
784
|
-
x:
|
785
|
-
np.frombuffer(i, dtype=matching_data.rotations.dtype).reshape(
|
786
|
-
candidates[0].ndim, candidates[0].ndim
|
787
|
-
)
|
788
|
-
)
|
1175
|
+
x: np.frombuffer(i, dtype=dtype).reshape(rot_dim, rot_dim)
|
789
1176
|
for i, x in candidates[3].items()
|
790
1177
|
}
|
791
|
-
|
792
|
-
candidates.append((target.origin, template.origin, target.sampling_rate, args))
|
1178
|
+
candidates.append((target.origin, template.origin, template.sampling_rate, args))
|
793
1179
|
write_pickle(data=candidates, filename=args.output)
|
794
1180
|
|
795
1181
|
runtime = time() - start
|
1182
|
+
print("\n" + "-" * 80)
|
796
1183
|
print(f"\nRuntime real: {runtime:.3f}s user: {(runtime * args.cores):.3f}s.")
|
797
1184
|
|
798
1185
|
|