pytme 0.2.9__cp311-cp311-macosx_15_0_arm64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (119) hide show
  1. pytme-0.2.9.data/scripts/estimate_ram_usage.py +97 -0
  2. pytme-0.2.9.data/scripts/match_template.py +1135 -0
  3. pytme-0.2.9.data/scripts/postprocess.py +622 -0
  4. pytme-0.2.9.data/scripts/preprocess.py +209 -0
  5. pytme-0.2.9.data/scripts/preprocessor_gui.py +1227 -0
  6. pytme-0.2.9.dist-info/METADATA +95 -0
  7. pytme-0.2.9.dist-info/RECORD +119 -0
  8. pytme-0.2.9.dist-info/WHEEL +5 -0
  9. pytme-0.2.9.dist-info/entry_points.txt +6 -0
  10. pytme-0.2.9.dist-info/licenses/LICENSE +153 -0
  11. pytme-0.2.9.dist-info/top_level.txt +3 -0
  12. scripts/__init__.py +0 -0
  13. scripts/estimate_ram_usage.py +97 -0
  14. scripts/match_template.py +1135 -0
  15. scripts/postprocess.py +622 -0
  16. scripts/preprocess.py +209 -0
  17. scripts/preprocessor_gui.py +1227 -0
  18. tests/__init__.py +0 -0
  19. tests/data/Blurring/blob_width18.npy +0 -0
  20. tests/data/Blurring/edgegaussian_sigma3.npy +0 -0
  21. tests/data/Blurring/gaussian_sigma2.npy +0 -0
  22. tests/data/Blurring/hamming_width6.npy +0 -0
  23. tests/data/Blurring/kaiserb_width18.npy +0 -0
  24. tests/data/Blurring/localgaussian_sigma0510.npy +0 -0
  25. tests/data/Blurring/mean_size5.npy +0 -0
  26. tests/data/Blurring/ntree_sigma0510.npy +0 -0
  27. tests/data/Blurring/rank_rank3.npy +0 -0
  28. tests/data/Maps/.DS_Store +0 -0
  29. tests/data/Maps/emd_8621.mrc.gz +0 -0
  30. tests/data/README.md +2 -0
  31. tests/data/Raw/em_map.map +0 -0
  32. tests/data/Structures/.DS_Store +0 -0
  33. tests/data/Structures/1pdj.cif +3339 -0
  34. tests/data/Structures/1pdj.pdb +1429 -0
  35. tests/data/Structures/5khe.cif +3685 -0
  36. tests/data/Structures/5khe.ent +2210 -0
  37. tests/data/Structures/5khe.pdb +2210 -0
  38. tests/data/Structures/5uz4.cif +70548 -0
  39. tests/preprocessing/__init__.py +0 -0
  40. tests/preprocessing/test_compose.py +76 -0
  41. tests/preprocessing/test_frequency_filters.py +178 -0
  42. tests/preprocessing/test_preprocessor.py +136 -0
  43. tests/preprocessing/test_utils.py +79 -0
  44. tests/test_analyzer.py +216 -0
  45. tests/test_backends.py +446 -0
  46. tests/test_density.py +503 -0
  47. tests/test_extensions.py +130 -0
  48. tests/test_matching_cli.py +283 -0
  49. tests/test_matching_data.py +162 -0
  50. tests/test_matching_exhaustive.py +124 -0
  51. tests/test_matching_memory.py +30 -0
  52. tests/test_matching_optimization.py +226 -0
  53. tests/test_matching_utils.py +189 -0
  54. tests/test_orientations.py +175 -0
  55. tests/test_parser.py +33 -0
  56. tests/test_rotations.py +153 -0
  57. tests/test_structure.py +247 -0
  58. tme/__init__.py +6 -0
  59. tme/__version__.py +1 -0
  60. tme/analyzer/__init__.py +2 -0
  61. tme/analyzer/_utils.py +186 -0
  62. tme/analyzer/aggregation.py +577 -0
  63. tme/analyzer/peaks.py +953 -0
  64. tme/backends/__init__.py +171 -0
  65. tme/backends/_cupy_utils.py +734 -0
  66. tme/backends/_jax_utils.py +188 -0
  67. tme/backends/cupy_backend.py +294 -0
  68. tme/backends/jax_backend.py +314 -0
  69. tme/backends/matching_backend.py +1270 -0
  70. tme/backends/mlx_backend.py +241 -0
  71. tme/backends/npfftw_backend.py +583 -0
  72. tme/backends/pytorch_backend.py +430 -0
  73. tme/data/__init__.py +0 -0
  74. tme/data/c48n309.npy +0 -0
  75. tme/data/c48n527.npy +0 -0
  76. tme/data/c48n9.npy +0 -0
  77. tme/data/c48u1.npy +0 -0
  78. tme/data/c48u1153.npy +0 -0
  79. tme/data/c48u1201.npy +0 -0
  80. tme/data/c48u1641.npy +0 -0
  81. tme/data/c48u181.npy +0 -0
  82. tme/data/c48u2219.npy +0 -0
  83. tme/data/c48u27.npy +0 -0
  84. tme/data/c48u2947.npy +0 -0
  85. tme/data/c48u3733.npy +0 -0
  86. tme/data/c48u4749.npy +0 -0
  87. tme/data/c48u5879.npy +0 -0
  88. tme/data/c48u7111.npy +0 -0
  89. tme/data/c48u815.npy +0 -0
  90. tme/data/c48u83.npy +0 -0
  91. tme/data/c48u8649.npy +0 -0
  92. tme/data/c600v.npy +0 -0
  93. tme/data/c600vc.npy +0 -0
  94. tme/data/metadata.yaml +80 -0
  95. tme/data/quat_to_numpy.py +42 -0
  96. tme/data/scattering_factors.pickle +0 -0
  97. tme/density.py +2263 -0
  98. tme/extensions.cpython-311-darwin.so +0 -0
  99. tme/external/bindings.cpp +332 -0
  100. tme/filters/__init__.py +6 -0
  101. tme/filters/_utils.py +311 -0
  102. tme/filters/bandpass.py +230 -0
  103. tme/filters/compose.py +81 -0
  104. tme/filters/ctf.py +393 -0
  105. tme/filters/reconstruction.py +160 -0
  106. tme/filters/wedge.py +542 -0
  107. tme/filters/whitening.py +191 -0
  108. tme/matching_data.py +863 -0
  109. tme/matching_exhaustive.py +497 -0
  110. tme/matching_optimization.py +1311 -0
  111. tme/matching_scores.py +1183 -0
  112. tme/matching_utils.py +1188 -0
  113. tme/memory.py +337 -0
  114. tme/orientations.py +598 -0
  115. tme/parser.py +685 -0
  116. tme/preprocessor.py +1329 -0
  117. tme/rotations.py +350 -0
  118. tme/structure.py +1864 -0
  119. tme/types.py +13 -0
@@ -0,0 +1,1135 @@
1
+ #!python
2
+ """ CLI for basic pyTME template matching functions.
3
+
4
+ Copyright (c) 2023 European Molecular Biology Laboratory
5
+
6
+ Author: Valentin Maurer <valentin.maurer@embl-hamburg.de>
7
+ """
8
+ import os
9
+ import argparse
10
+ import warnings
11
+ from sys import exit
12
+ from time import time
13
+ from typing import Tuple
14
+ from copy import deepcopy
15
+ from os.path import exists
16
+ from tempfile import gettempdir
17
+
18
+ import numpy as np
19
+
20
+ from tme.backends import backend as be
21
+ from tme import Density, __version__
22
+ from tme.matching_utils import scramble_phases, write_pickle
23
+ from tme.matching_exhaustive import scan_subsets, MATCHING_EXHAUSTIVE_REGISTER
24
+ from tme.rotations import (
25
+ get_cone_rotations,
26
+ get_rotation_matrices,
27
+ )
28
+ from tme.matching_data import MatchingData
29
+ from tme.analyzer import (
30
+ MaxScoreOverRotations,
31
+ PeakCallerMaximumFilter,
32
+ )
33
+ from tme.filters import (
34
+ CTF,
35
+ Wedge,
36
+ Compose,
37
+ BandPassFilter,
38
+ WedgeReconstructed,
39
+ ReconstructFromTilt,
40
+ LinearWhiteningFilter,
41
+ )
42
+
43
+
44
def get_func_fullname(func) -> str:
    """Return a qualified identifier string for *func*, including its module."""
    module, name = func.__module__, func.__name__
    return "<function '%s.%s'>" % (module, name)
47
+
48
+
49
def print_block(name: str, data: dict, label_width=20) -> None:
    """Pretty-print a titled block of key/value pairs to stdout."""
    print(f"\n> {name}")
    for label, entry in data.items():
        # Arrays are summarized by their shape rather than dumped wholesale.
        shown = entry.shape if isinstance(entry, np.ndarray) else entry
        print(f" - {label + ':':<{label_width}} {str(shown)}")
57
+
58
+
59
def print_entry() -> None:
    """Print the pytme startup banner with the version centered in a star frame."""
    width = 80
    text = f" pytme v{__version__} "
    # Two columns are reserved for the framing '*' characters.
    fill = width - len(text) - 2
    left = fill // 2
    right = fill - left

    print("*" * width)
    print(f"*{' ' * left}{text}{' ' * right}*")
    print("*" * width)
69
+
70
+
71
def check_positive(value):
    """Argparse type callable parsing *value* as a strictly positive float.

    Parameters
    ----------
    value : str
        Command-line token to convert.

    Returns
    -------
    float
        The parsed, strictly positive value.

    Raises
    ------
    argparse.ArgumentTypeError
        If *value* cannot be parsed as a float, is NaN, or is <= 0.
    """
    try:
        fvalue = float(value)
    except (TypeError, ValueError):
        raise argparse.ArgumentTypeError("%s is an invalid positive float." % value)
    # 'not fvalue > 0' (instead of 'fvalue <= 0') also rejects NaN, which
    # compares False against everything and previously slipped through.
    if not fvalue > 0:
        raise argparse.ArgumentTypeError("%s is an invalid positive float." % value)
    return fvalue
76
+
77
+
78
def load_and_validate_mask(mask_target: "Density", mask_path: str, **kwargs):
    """
    Loads a mask in CCP4/MRC format and assesses whether the sampling_rate
    and shape match its target.

    Parameters
    ----------
    mask_target : Density
        Object the mask should be applied to.
    mask_path : str
        Path to the mask in CCP4/MRC format.
    kwargs : dict, optional
        Keyword arguments passed to :py:meth:`tme.density.Density.from_file`.

    Raises
    ------
    ValueError
        If shape or sampling rate do not match between mask_target and mask.

    Returns
    -------
    Density
        A density instance if the mask was validated and loaded, otherwise None.
    """
    if mask_path is None:
        return None

    mask = Density.from_file(mask_path, **kwargs)
    # The mask is assumed to live in the same coordinate frame as its target.
    mask.origin = deepcopy(mask_target.origin)
    if not np.allclose(mask.shape, mask_target.shape):
        # Fixed message: previously printed a stray literal 'f' ("got f(...)").
        raise ValueError(
            f"Expected shape of {mask_path} was {mask_target.shape},"
            f" got {mask.shape}"
        )
    if not np.allclose(
        np.round(mask.sampling_rate, 2), np.round(mask_target.sampling_rate, 2)
    ):
        raise ValueError(
            f"Expected sampling_rate of {mask_path} was {mask_target.sampling_rate}"
            f", got {mask.sampling_rate}"
        )
    return mask
118
+
119
+
120
def parse_rotation_logic(args, ndim):
    """Build the rotation set for template matching from CLI arguments.

    Parameters
    ----------
    args : argparse.Namespace
        Parsed CLI arguments; uses the angular/cone/axis sampling options.
    ndim : int
        Dimensionality of the returned rotation matrices.

    Returns
    -------
    numpy.ndarray
        Rotation matrices with shape (n, ndim, ndim).
    """
    if args.angular_sampling is not None:
        # Sampling rates >= 180 degrees correspond to the identity only, so
        # return early instead of computing a full rotation set that the
        # original code immediately discarded.
        if args.angular_sampling >= 180:
            return np.eye(ndim).reshape(1, ndim, ndim)
        return get_rotation_matrices(
            angular_sampling=args.angular_sampling,
            dim=ndim,
            use_optimized_set=not args.no_use_optimized_set,
        )

    # Axis sampling defaults to the cone sampling rate when unspecified.
    if args.axis_sampling is None:
        args.axis_sampling = args.cone_sampling

    return get_cone_rotations(
        cone_angle=args.cone_angle,
        cone_sampling=args.cone_sampling,
        axis_angle=args.axis_angle,
        axis_sampling=args.axis_sampling,
        n_symmetry=args.axis_symmetry,
        axis=[0 if i != args.cone_axis else 1 for i in range(ndim)],
        reference=[0, 0, -1],
    )
144
+
145
+
146
def compute_schedule(
    args,
    matching_data: MatchingData,
    callback_class,
    pad_edges: bool = False,
):
    """Derive a parallelization schedule, retrying with edge padding if needed."""
    # An explicit user request for target padding takes precedence.
    pad_edges = pad_edges or (args.pad_edges is True)

    splits, schedule = matching_data.computation_schedule(
        matching_method=args.score,
        analyzer_method=callback_class.__name__,
        use_gpu=args.use_gpu,
        pad_fourier=False,
        pad_target_edges=pad_edges,
        available_memory=args.memory,
        max_cores=args.cores,
    )

    if splits is None:
        print(
            "Found no suitable parallelization schedule. Consider increasing"
            " available RAM or decreasing number of cores."
        )
        exit(-1)

    total_splits = np.prod(list(splits.values()))
    # Splitting an unpadded target can bias scores at split boundaries, so
    # redo the schedule once with padding enabled.
    needs_padding = (
        not pad_edges and len(matching_data._target_dim) == 0 and total_splits > 1
    )
    if needs_padding:
        args.pad_edges = True
        return compute_schedule(args, matching_data, callback_class, True)
    return splits, schedule
178
+
179
+
180
def setup_filter(args, template: Density, target: Density) -> Tuple[Compose, Compose]:
    """Assemble frequency-space filter stacks for template and target.

    Builds (possibly empty) lists of filter objects from the CLI arguments
    (wedge/tilt options, CTF options, band-pass, spectral whitening) and wraps
    each non-empty list in a Compose. Returns ``(template_filter, target_filter)``
    where either element may be None.
    """
    needs_reconstruction = False
    template_filter, target_filter = [], []
    if args.tilt_angles is not None:
        needs_reconstruction = args.tilt_weighting is not None
        try:
            # --tilt_angles may be a file of angles (and optional weights) ...
            wedge = Wedge.from_file(args.tilt_angles)
            wedge.weight_type = args.tilt_weighting
            if args.tilt_weighting in ("angle", None) and args.ctf_file is None:
                wedge = WedgeReconstructed(
                    angles=wedge.angles,
                    weight_wedge=args.tilt_weighting == "angle",
                    opening_axis=args.wedge_axes[0],
                    tilt_axis=args.wedge_axes[1],
                )
        except FileNotFoundError:
            # ... or a "start,stop" / "start,stop:step" angle range string
            # (parse_args guarantees the first comma field parses as float).
            tilt_step, create_continuous_wedge = None, True
            tilt_start, tilt_stop = args.tilt_angles.split(",")
            if ":" in tilt_stop:
                create_continuous_wedge = False
                tilt_stop, tilt_step = tilt_stop.split(":")
            tilt_start, tilt_stop = float(tilt_start), float(tilt_stop)
            tilt_angles = (tilt_start, tilt_stop)
            if tilt_step is not None:
                tilt_step = float(tilt_step)
                # NOTE(review): range runs from -tilt_start to +tilt_stop,
                # i.e. the first number is interpreted as the magnitude of the
                # most negative stage tilt — confirm against CLI help text.
                tilt_angles = np.arange(
                    -tilt_start, tilt_stop + tilt_step, tilt_step
                ).tolist()

            if args.tilt_weighting is not None and tilt_step is None:
                raise ValueError(
                    "Tilt weighting is not supported for continuous wedges."
                )
            if args.tilt_weighting not in ("angle", None):
                raise ValueError(
                    "Tilt weighting schemes other than 'angle' or 'None' require "
                    "a specification of electron doses via --tilt_angles."
                )

            wedge = Wedge(
                angles=tilt_angles,
                opening_axis=args.wedge_axes[0],
                tilt_axis=args.wedge_axes[1],
                shape=None,
                weight_type=None,
                weights=np.ones_like(tilt_angles),
            )
            if args.tilt_weighting in ("angle", None) and args.ctf_file is None:
                wedge = WedgeReconstructed(
                    angles=tilt_angles,
                    weight_wedge=args.tilt_weighting == "angle",
                    create_continuous_wedge=create_continuous_wedge,
                    reconstruction_filter=args.reconstruction_filter,
                    opening_axis=args.wedge_axes[0],
                    tilt_axis=args.wedge_axes[1],
                )
            # The target only gets a binary continuous wedge spanning the
            # extremal tilt angles.
            wedge_target = WedgeReconstructed(
                angles=(np.abs(np.min(tilt_angles)), np.abs(np.max(tilt_angles))),
                weight_wedge=False,
                create_continuous_wedge=True,
                opening_axis=args.wedge_axes[0],
                tilt_axis=args.wedge_axes[1],
            )
            target_filter.append(wedge_target)

        wedge.sampling_rate = template.sampling_rate
        template_filter.append(wedge)
        # A per-tilt Wedge operates on N-D tilts and needs a reconstruction
        # step; WedgeReconstructed is already (N+1)-D.
        if not isinstance(wedge, WedgeReconstructed):
            reconstruction_filter = ReconstructFromTilt(
                reconstruction_filter=args.reconstruction_filter,
                interpolation_order=args.reconstruction_interpolation_order,
            )
            template_filter.append(reconstruction_filter)

    if args.ctf_file is not None or args.defocus is not None:
        needs_reconstruction = True
        if args.ctf_file is not None:
            # parse_args enforces that --ctf_file implies --tilt_angles, so
            # 'wedge' is defined on this path.
            ctf = CTF.from_file(args.ctf_file)
            n_tilts_ctfs, n_tils_angles = len(ctf.defocus_x), len(wedge.angles)
            if n_tilts_ctfs != n_tils_angles:
                raise ValueError(
                    f"CTF file contains {n_tilts_ctfs} micrographs, but match_template "
                    f"recieved {n_tils_angles} tilt angles. Expected one angle "
                    "per micrograph."
                )
            ctf.angles = wedge.angles
            ctf.no_reconstruction = False
            ctf.opening_axis, ctf.tilt_axis = args.wedge_axes
        else:
            # Single-defocus CTF; applied directly without reconstruction.
            needs_reconstruction = False
            ctf = CTF(
                defocus_x=[args.defocus],
                phase_shift=[args.phase_shift],
                defocus_y=None,
                angles=[0],
                shape=None,
                return_real_fourier=True,
            )
        ctf.sampling_rate = template.sampling_rate
        ctf.flip_phase = args.no_flip_phase
        ctf.amplitude_contrast = args.amplitude_contrast
        ctf.spherical_aberration = args.spherical_aberration
        # CLI voltage is in kV; CTF expects volts.
        ctf.acceleration_voltage = args.acceleration_voltage * 1e3
        ctf.correct_defocus_gradient = args.correct_defocus_gradient

        # The CTF must be applied to the tilts, i.e. before reconstruction.
        if not needs_reconstruction:
            template_filter.append(ctf)
        elif isinstance(template_filter[-1], ReconstructFromTilt):
            template_filter.insert(-1, ctf)
        else:
            template_filter.insert(0, ctf)
            template_filter.insert(
                1,
                ReconstructFromTilt(
                    reconstruction_filter=args.reconstruction_filter,
                    interpolation_order=args.reconstruction_interpolation_order,
                ),
            )

    # NOTE(review): due to operator precedence this reads as
    # 'args.lowpass or (args.highpass is not None)'; a lowpass of 0.0 would
    # be treated as unset — confirm whether 'is not None' was intended.
    if args.lowpass or args.highpass is not None:
        lowpass, highpass = args.lowpass, args.highpass
        # Convert voxel / frequency units to sampling-rate units.
        if args.pass_format == "voxel":
            if lowpass is not None:
                lowpass = np.max(np.multiply(lowpass, template.sampling_rate))
            if highpass is not None:
                highpass = np.max(np.multiply(highpass, template.sampling_rate))
        elif args.pass_format == "frequency":
            if lowpass is not None:
                lowpass = np.max(np.divide(template.sampling_rate, lowpass))
            if highpass is not None:
                highpass = np.max(np.divide(template.sampling_rate, highpass))

        # Best-effort sanity warning; comparison fails when either is None.
        try:
            if args.lowpass >= args.highpass:
                warnings.warn("--lowpass should be smaller than --highpass.")
        except Exception:
            pass

        bandpass = BandPassFilter(
            use_gaussian=args.no_pass_smooth,
            lowpass=lowpass,
            highpass=highpass,
            sampling_rate=template.sampling_rate,
        )
        template_filter.append(bandpass)
        target_filter.append(bandpass)

    if args.whiten_spectrum:
        whitening_filter = LinearWhiteningFilter()
        template_filter.append(whitening_filter)
        target_filter.append(whitening_filter)

    if needs_reconstruction and args.reconstruction_filter is None:
        warnings.warn(
            "Consider using a --reconstruction_filter such as 'ramp' to avoid artifacts."
        )

    # Empty filter stacks collapse to None.
    template_filter = Compose(template_filter) if len(template_filter) else None
    target_filter = Compose(target_filter) if len(target_filter) else None
    if args.no_filter_target:
        target_filter = None

    return template_filter, target_filter
343
+
344
+
345
def parse_args():
    """Parse and validate match_template CLI arguments.

    Returns
    -------
    argparse.Namespace
        Parsed arguments with derived fields set (version, temp_directory,
        gpu_indices as a list of ints, wedge_axes as a tuple of ints).

    Raises
    ------
    ValueError
        On inconsistent combinations of --tilt_angles, --wedge_axes and
        --ctf_file, or an unknown score.
    """
    parser = argparse.ArgumentParser(
        description="Perform template matching.",
        formatter_class=argparse.ArgumentDefaultsHelpFormatter,
    )

    io_group = parser.add_argument_group("Input / Output")
    io_group.add_argument(
        "-m",
        "--target",
        dest="target",
        type=str,
        required=True,
        help="Path to a target in CCP4/MRC, EM, H5 or another format supported by "
        "tme.density.Density.from_file "
        "https://kosinskilab.github.io/pyTME/reference/api/tme.density.Density.from_file.html",
    )
    io_group.add_argument(
        "--target_mask",
        dest="target_mask",
        type=str,
        required=False,
        help="Path to a mask for the target in a supported format (see target).",
    )
    io_group.add_argument(
        "-i",
        "--template",
        dest="template",
        type=str,
        required=True,
        help="Path to a template in PDB/MMCIF or other supported formats (see target).",
    )
    io_group.add_argument(
        "--template_mask",
        dest="template_mask",
        type=str,
        required=False,
        help="Path to a mask for the template in a supported format (see target).",
    )
    io_group.add_argument(
        "-o",
        "--output",
        dest="output",
        type=str,
        required=False,
        default="output.pickle",
        help="Path to the output pickle file.",
    )
    io_group.add_argument(
        "--invert_target_contrast",
        dest="invert_target_contrast",
        action="store_true",
        default=False,
        help="Invert the target's contrast for cases where templates to-be-matched have "
        "negative values, e.g. tomograms.",
    )
    io_group.add_argument(
        "--scramble_phases",
        dest="scramble_phases",
        action="store_true",
        default=False,
        help="Phase scramble the template to generate a noise score background.",
    )

    scoring_group = parser.add_argument_group("Scoring")
    scoring_group.add_argument(
        "-s",
        dest="score",
        type=str,
        default="FLCSphericalMask",
        choices=list(MATCHING_EXHAUSTIVE_REGISTER.keys()),
        help="Template matching scoring function.",
    )

    angular_group = parser.add_argument_group("Angular Sampling")
    angular_exclusive = angular_group.add_mutually_exclusive_group(required=True)

    angular_exclusive.add_argument(
        "-a",
        dest="angular_sampling",
        type=check_positive,
        default=None,
        help="Angular sampling rate using optimized rotational sets."
        "A lower number yields more rotations. Values >= 180 sample only the identity.",
    )
    angular_exclusive.add_argument(
        "--cone_angle",
        dest="cone_angle",
        type=check_positive,
        default=None,
        help="Half-angle of the cone to be sampled in degrees. Allows to sample a "
        "narrow interval around a known orientation, e.g. for surface oversampling.",
    )
    # NOTE(review): check_positive rejects axis index 0; confirm whether
    # --cone_axis should accept 0 (x-axis) and use a plain int type instead.
    angular_group.add_argument(
        "--cone_axis",
        dest="cone_axis",
        type=check_positive,
        default=2,
        help="Principal axis to build cone around.",
    )
    angular_group.add_argument(
        "--invert_cone",
        dest="invert_cone",
        action="store_true",
        help="Invert cone handedness.",
    )
    angular_group.add_argument(
        "--cone_sampling",
        dest="cone_sampling",
        type=check_positive,
        default=None,
        help="Sampling rate of the cone in degrees.",
    )
    angular_group.add_argument(
        "--axis_angle",
        dest="axis_angle",
        type=check_positive,
        default=360.0,
        required=False,
        help="Sampling angle along the z-axis of the cone.",
    )
    angular_group.add_argument(
        "--axis_sampling",
        dest="axis_sampling",
        type=check_positive,
        default=None,
        required=False,
        help="Sampling rate along the z-axis of the cone. Defaults to --cone_sampling.",
    )
    angular_group.add_argument(
        "--axis_symmetry",
        dest="axis_symmetry",
        type=check_positive,
        default=1,
        required=False,
        help="N-fold symmetry around z-axis of the cone.",
    )
    angular_group.add_argument(
        "--no_use_optimized_set",
        dest="no_use_optimized_set",
        action="store_true",
        default=False,
        required=False,
        help="Whether to use random uniform instead of optimized rotation sets.",
    )

    computation_group = parser.add_argument_group("Computation")
    computation_group.add_argument(
        "-n",
        dest="cores",
        required=False,
        type=int,
        default=4,
        help="Number of cores used for template matching.",
    )
    computation_group.add_argument(
        "--use_gpu",
        dest="use_gpu",
        action="store_true",
        default=False,
        help="Whether to perform computations on the GPU.",
    )
    computation_group.add_argument(
        "--gpu_indices",
        dest="gpu_indices",
        type=str,
        default=None,
        help="Comma-separated list of GPU indices to use. For example,"
        " 0,1 for the first and second GPU. Only used if --use_gpu is set."
        " If not provided but --use_gpu is set, CUDA_VISIBLE_DEVICES will"
        " be respected.",
    )
    computation_group.add_argument(
        "-r",
        "--ram",
        dest="memory",
        required=False,
        type=int,
        default=None,
        help="Amount of memory that can be used in bytes.",
    )
    computation_group.add_argument(
        "--memory_scaling",
        dest="memory_scaling",
        required=False,
        type=float,
        default=0.85,
        help="Fraction of available memory to be used. Ignored if --ram is set.",
    )
    computation_group.add_argument(
        "--temp_directory",
        dest="temp_directory",
        default=None,
        help="Directory for temporary objects. Faster I/O improves runtime.",
    )
    computation_group.add_argument(
        "--backend",
        dest="backend",
        default=None,
        choices=be.available_backends(),
        help="[Expert] Overwrite default computation backend.",
    )
    filter_group = parser.add_argument_group("Filters")
    filter_group.add_argument(
        "--lowpass",
        dest="lowpass",
        type=float,
        required=False,
        help="Resolution to lowpass filter template and target to in the same unit "
        "as the sampling rate of template and target (typically Ångstrom).",
    )
    filter_group.add_argument(
        "--highpass",
        dest="highpass",
        type=float,
        required=False,
        help="Resolution to highpass filter template and target to in the same unit "
        "as the sampling rate of template and target (typically Ångstrom).",
    )
    filter_group.add_argument(
        "--no_pass_smooth",
        dest="no_pass_smooth",
        action="store_false",
        default=True,
        help="Whether a hard edge filter should be used for --lowpass and --highpass.",
    )
    filter_group.add_argument(
        "--pass_format",
        dest="pass_format",
        type=str,
        required=False,
        default="sampling_rate",
        choices=["sampling_rate", "voxel", "frequency"],
        help="How values passed to --lowpass and --highpass should be interpreted. ",
    )
    filter_group.add_argument(
        "--whiten_spectrum",
        dest="whiten_spectrum",
        action="store_true",
        default=None,
        help="Apply spectral whitening to template and target based on target spectrum.",
    )
    filter_group.add_argument(
        "--wedge_axes",
        dest="wedge_axes",
        type=str,
        required=False,
        default=None,
        help="Indices of wedge opening and tilt axis, e.g. '2,0' for a wedge open "
        "in z and tilted over the x-axis.",
    )
    filter_group.add_argument(
        "--tilt_angles",
        dest="tilt_angles",
        type=str,
        required=False,
        default=None,
        help="Path to a tab-separated file containing the column angles and optionally "
        " weights, or comma separated start and stop stage tilt angle, e.g. 50,45, which "
        " yields a continuous wedge mask. Alternatively, a tilt step size can be "
        "specified like 50,45:5.0 to sample 5.0 degree tilt angle steps.",
    )
    filter_group.add_argument(
        "--tilt_weighting",
        dest="tilt_weighting",
        type=str,
        required=False,
        choices=["angle", "relion", "grigorieff"],
        default=None,
        help="Weighting scheme used to reweight individual tilts. Available options: "
        "angle (cosine based weighting), "
        "relion (relion formalism for wedge weighting) requires,"
        "grigorieff (exposure filter as defined in Grant and Grigorieff 2015)."
        "relion and grigorieff require electron doses in --tilt_angles weights column.",
    )
    filter_group.add_argument(
        "--reconstruction_filter",
        dest="reconstruction_filter",
        type=str,
        required=False,
        choices=["ram-lak", "ramp", "ramp-cont", "shepp-logan", "cosine", "hamming"],
        default=None,
        help="Filter applied when reconstructing (N+1)-D from N-D filters.",
    )
    filter_group.add_argument(
        "--reconstruction_interpolation_order",
        dest="reconstruction_interpolation_order",
        type=int,
        default=1,
        required=False,
        help="Analogous to --interpolation_order but for reconstruction.",
    )
    filter_group.add_argument(
        "--no_filter_target",
        dest="no_filter_target",
        action="store_true",
        default=False,
        help="Whether to not apply potential filters to the target.",
    )

    ctf_group = parser.add_argument_group("Contrast Transfer Function")
    ctf_group.add_argument(
        "--ctf_file",
        dest="ctf_file",
        type=str,
        required=False,
        default=None,
        help="Path to a file with CTF parameters from CTFFIND4. Each line will be "
        "interpreted as tilt obtained at the angle specified in --tilt_angles. ",
    )
    ctf_group.add_argument(
        "--defocus",
        dest="defocus",
        type=float,
        required=False,
        default=None,
        help="Defocus in units of sampling rate (typically Ångstrom). "
        "Superseded by --ctf_file.",
    )
    ctf_group.add_argument(
        "--phase_shift",
        dest="phase_shift",
        type=float,
        required=False,
        default=0,
        help="Phase shift in degrees. Superseded by --ctf_file.",
    )
    ctf_group.add_argument(
        "--acceleration_voltage",
        dest="acceleration_voltage",
        type=float,
        required=False,
        default=300,
        help="Acceleration voltage in kV.",
    )
    ctf_group.add_argument(
        "--spherical_aberration",
        dest="spherical_aberration",
        type=float,
        required=False,
        default=2.7e7,
        help="Spherical aberration in units of sampling rate (typically Ångstrom).",
    )
    ctf_group.add_argument(
        "--amplitude_contrast",
        dest="amplitude_contrast",
        type=float,
        required=False,
        default=0.07,
        help="Amplitude contrast.",
    )
    ctf_group.add_argument(
        "--no_flip_phase",
        dest="no_flip_phase",
        action="store_false",
        required=False,
        help="Do not perform phase-flipping CTF correction.",
    )
    ctf_group.add_argument(
        "--correct_defocus_gradient",
        dest="correct_defocus_gradient",
        action="store_true",
        required=False,
        help="[Experimental] Whether to compute a more accurate 3D CTF incorporating "
        "defocus gradients.",
    )

    performance_group = parser.add_argument_group("Performance")
    performance_group.add_argument(
        "--no_centering",
        dest="no_centering",
        action="store_true",
        help="Assumes the template is already centered and omits centering.",
    )
    performance_group.add_argument(
        "--pad_edges",
        dest="pad_edges",
        action="store_true",
        default=False,
        help="Whether to pad the edges of the target. Useful if the target does not "
        "a well-defined bounding box. Defaults to True if splitting is required.",
    )
    performance_group.add_argument(
        "--pad_filter",
        dest="pad_filter",
        action="store_true",
        default=False,
        help="Pads the filter to the shape of the target. Particularly useful for fast "
        "oscilating filters to avoid aliasing effects.",
    )
    performance_group.add_argument(
        "--interpolation_order",
        dest="interpolation_order",
        required=False,
        type=int,
        default=3,
        help="Spline interpolation used for rotations.",
    )
    performance_group.add_argument(
        "--use_mixed_precision",
        dest="use_mixed_precision",
        action="store_true",
        default=False,
        help="Use float16 for real values operations where possible.",
    )
    performance_group.add_argument(
        "--use_memmap",
        dest="use_memmap",
        action="store_true",
        default=False,
        help="Use memmaps to offload large data objects to disk. "
        "Particularly useful for large inputs in combination with --use_gpu.",
    )

    analyzer_group = parser.add_argument_group("Analyzer")
    analyzer_group.add_argument(
        "--score_threshold",
        dest="score_threshold",
        required=False,
        type=float,
        default=0,
        help="Minimum template matching scores to consider for analysis.",
    )
    analyzer_group.add_argument(
        "-p",
        dest="peak_calling",
        action="store_true",
        default=False,
        help="Perform peak calling instead of score aggregation.",
    )
    # Fixed: this option was declared with action="store_true", so passing
    # --number_of_peaks set the value to True instead of an integer count.
    analyzer_group.add_argument(
        "--number_of_peaks",
        dest="number_of_peaks",
        type=int,
        default=1000,
        help="Number of peaks to call, 1000 by default.",
    )
    args = parser.parse_args()
    args.version = __version__

    # Negative orders disable spline interpolation entirely.
    if args.interpolation_order < 0:
        args.interpolation_order = None

    if args.temp_directory is None:
        args.temp_directory = gettempdir()

    os.environ["TMPDIR"] = args.temp_directory
    if args.score not in MATCHING_EXHAUSTIVE_REGISTER:
        raise ValueError(
            f"score has to be one of {', '.join(MATCHING_EXHAUSTIVE_REGISTER.keys())}"
        )

    # --gpu_indices takes precedence over a pre-set CUDA_VISIBLE_DEVICES.
    if args.gpu_indices is not None:
        os.environ["CUDA_VISIBLE_DEVICES"] = args.gpu_indices

    if args.use_gpu:
        gpu_devices = os.environ.get("CUDA_VISIBLE_DEVICES", None)
        if gpu_devices is None:
            print(
                "No GPU indices provided and CUDA_VISIBLE_DEVICES is not set.",
                "Assuming device 0.",
            )
            os.environ["CUDA_VISIBLE_DEVICES"] = "0"
        args.gpu_indices = [
            int(x) for x in os.environ["CUDA_VISIBLE_DEVICES"].split(",")
        ]

    if args.tilt_angles is not None:
        if args.wedge_axes is None:
            raise ValueError("Need to specify --wedge_axes when --tilt_angles is set.")
        if not exists(args.tilt_angles):
            # Not a file: must at least look like a numeric angle range.
            try:
                float(args.tilt_angles.split(",")[0])
            except ValueError:
                raise ValueError(f"{args.tilt_angles} is not a file nor a range.")

    if args.ctf_file is not None and args.tilt_angles is None:
        raise ValueError("Need to specify --tilt_angles when --ctf_file is set.")

    if args.wedge_axes is not None:
        args.wedge_axes = tuple(int(i) for i in args.wedge_axes.split(","))

    return args
829
+
830
+
831
def main():
    """CLI entry point: exhaustive template matching of a template in a target.

    Loads the target and template densities (plus optional masks), selects a
    compute backend, derives a split/schedule that fits the memory budget,
    runs the matching, and pickles the analyzer output to ``args.output``.

    Raises
    ------
    ValueError
        Indirectly via :func:`parse_args` for inconsistent CLI arguments.
    """
    args = parse_args()
    print_entry()

    # Memory-map the target; tomograms are typically too large to hold in RAM.
    target = Density.from_file(args.target, use_memmap=True)

    # The template may be a density map or an atomic structure file. Try the
    # map reader first, then fall back to rasterizing the structure at the
    # target's sampling rate.
    try:
        template = Density.from_file(args.template)
    except Exception:
        template = Density.from_structure(
            filename_or_structure=args.template,
            sampling_rate=target.sampling_rate,
        )

    # Resample the template if sampling rates disagree beyond two decimals.
    if target.sampling_rate.size == template.sampling_rate.size:
        if not np.allclose(
            np.round(target.sampling_rate, 2), np.round(template.sampling_rate, 2)
        ):
            print(
                f"Resampling template to {target.sampling_rate}. "
                "Consider providing a template with the same sampling rate as the target."
            )
            template = template.resample(target.sampling_rate, order=3)

    template_mask = load_and_validate_mask(
        mask_target=template, mask_path=args.template_mask
    )
    target_mask = load_and_validate_mask(
        mask_target=target, mask_path=args.target_mask, use_memmap=True
    )

    initial_shape = target.shape
    print_block(
        name="Target",
        data={
            "Initial Shape": initial_shape,
            "Sampling Rate": tuple(np.round(target.sampling_rate, 2)),
            "Final Shape": target.shape,
        },
    )

    if target_mask:
        print_block(
            name="Target Mask",
            data={
                "Initial Shape": initial_shape,
                "Sampling Rate": tuple(np.round(target_mask.sampling_rate, 2)),
                "Final Shape": target_mask.shape,
            },
        )

    # Center the template in its box unless disabled; remember the applied
    # translation so the mask can be shifted consistently below.
    initial_shape = template.shape
    translation = np.zeros(len(template.shape), dtype=np.float32)
    if not args.no_centering:
        template, translation = template.centered(0)
    print_block(
        name="Template",
        data={
            "Initial Shape": initial_shape,
            "Sampling Rate": tuple(np.round(template.sampling_rate, 2)),
            "Final Shape": template.shape,
        },
    )

    if template_mask is None:
        # No mask provided: default to an all-ones mask over the template box.
        template_mask = template.empty
        if not args.no_centering:
            enclosing_box = template.minimum_enclosing_box(
                0, use_geometric_center=False
            )
            template_mask.adjust_box(enclosing_box)

        template_mask.data[:] = 1
        # The synthetic mask is already aligned; no translation to undo.
        translation = np.zeros_like(translation)

    # Bring the mask onto the (possibly re-centered) template grid.
    template_mask.pad(template.shape, center=False)
    origin_translation = np.divide(
        np.subtract(template.origin, template_mask.origin), template.sampling_rate
    )
    translation = np.add(translation, origin_translation)

    template_mask = template_mask.rigid_transform(
        rotation_matrix=np.eye(template_mask.data.ndim),
        translation=-translation,
        order=1,
    )
    template_mask.origin = template.origin.copy()
    print_block(
        name="Template Mask",
        data={
            # BUGFIX: label previously read "Inital Shape".
            "Initial Shape": initial_shape,
            "Sampling Rate": tuple(np.round(template_mask.sampling_rate, 2)),
            "Final Shape": template_mask.shape,
        },
    )
    print("\n" + "-" * 80)

    if args.scramble_phases:
        # Phase scrambling yields a noise template for background estimation.
        template.data = scramble_phases(
            template.data, noise_proportion=1.0, normalize_power=False
        )

    # Determine suitable backend for the selected operation.
    available_backends = be.available_backends()
    if args.backend is not None:
        req_backend = args.backend
        if req_backend not in available_backends:
            raise ValueError("Requested backend is not available.")
        available_backends = [req_backend]

    be_selection = ("numpyfftw", "pytorch", "jax", "mlx")
    if args.use_gpu:
        args.cores = len(args.gpu_indices)
        be_selection = ("pytorch", "cupy", "jax")
    if args.use_mixed_precision:
        # Mixed precision is only wired up for these backends.
        be_selection = tuple(x for x in be_selection if x in ("cupy", "numpyfftw"))

    available_backends = [x for x in available_backends if x in be_selection]
    if args.peak_calling:
        # Peak calling is not supported on the jax backend.
        if "jax" in available_backends:
            available_backends.remove("jax")
        if args.use_gpu and "pytorch" in available_backends:
            available_backends = ("pytorch",)

    # dim_match = len(template.shape) == len(target.shape) <= 3
    # if dim_match and args.use_gpu and "jax" in available_backends:
    #     args.interpolation_order = 1
    #     available_backends = ["jax"]

    # Pick the first available backend in preference order.
    backend_preference = ("numpyfftw", "pytorch", "jax", "mlx")
    if args.use_gpu:
        backend_preference = ("cupy", "pytorch", "jax")
    for pref in backend_preference:
        if pref not in available_backends:
            continue
        be.change_backend(pref)
        if pref == "pytorch":
            be.change_backend(pref, device="cuda" if args.use_gpu else "cpu")

        if args.use_mixed_precision:
            be.change_backend(
                backend_name=pref,
                default_dtype=be._array_backend.float16,
                complex_dtype=be._array_backend.complex64,
                default_dtype_int=be._array_backend.int16,
            )
        break

    if pref == "pytorch" and args.interpolation_order == 3:
        warnings.warn(
            "Pytorch does not support --interpolation_order 3, setting it to 1."
        )
        args.interpolation_order = 1

    available_memory = be.get_available_memory() * be.device_count()
    if args.memory is None:
        args.memory = int(args.memory_scaling * available_memory)

    callback_class = MaxScoreOverRotations
    if args.peak_calling:
        callback_class = PeakCallerMaximumFilter

    matching_data = MatchingData(
        target=target,
        template=template.data,
        target_mask=target_mask,
        template_mask=template_mask,
        invert_target=args.invert_target_contrast,
        rotations=parse_rotation_logic(args=args, ndim=template.data.ndim),
    )

    matching_setup, matching_score = MATCHING_EXHAUSTIVE_REGISTER[args.score]
    matching_data.template_filter, matching_data.target_filter = setup_filter(
        args, template, target
    )

    matching_data.set_matching_dimension(
        target_dim=target.metadata.get("batch_dimension", None),
        template_dim=template.metadata.get("batch_dimension", None),
    )
    splits, schedule = compute_schedule(args, matching_data, callback_class)

    n_splits = np.prod(list(splits.values()))
    target_split = ", ".join(
        [":".join([str(x) for x in axis]) for axis in splits.items()]
    )
    gpus_used = 0 if args.gpu_indices is None else len(args.gpu_indices)
    options = {
        "Angular Sampling": f"{args.angular_sampling}"
        f" [{matching_data.rotations.shape[0]} rotations]",
        "Center Template": not args.no_centering,
        "Scramble Template": args.scramble_phases,
        "Invert Contrast": args.invert_target_contrast,
        "Extend Target Edges": args.pad_edges,
        "Interpolation Order": args.interpolation_order,
        "Setup Function": f"{get_func_fullname(matching_setup)}",
        "Scoring Function": f"{get_func_fullname(matching_score)}",
    }

    print_block(
        name="Template Matching",
        data=options,
        label_width=max(len(key) for key in options.keys()) + 3,
    )

    compute_options = {
        "Backend": be._BACKEND_REGISTRY[be._backend_name],
        "Compute Devices": f"CPU [{args.cores}], GPU [{gpus_used}]",
        "Use Mixed Precision": args.use_mixed_precision,
        "Assigned Memory [MB]": f"{args.memory // 1e6} [out of {available_memory//1e6}]",
        "Temporary Directory": args.temp_directory,
        "Target Splits": f"{target_split} [N={n_splits}]",
    }
    print_block(
        name="Computation",
        data=compute_options,
        # BUGFIX: label width was computed from `options`, misaligning this table.
        label_width=max(len(key) for key in compute_options.keys()) + 3,
    )

    filter_args = {
        "Lowpass": args.lowpass,
        "Highpass": args.highpass,
        "Smooth Pass": args.no_pass_smooth,
        "Pass Format": args.pass_format,
        "Spectral Whitening": args.whiten_spectrum,
        "Wedge Axes": args.wedge_axes,
        "Tilt Angles": args.tilt_angles,
        "Tilt Weighting": args.tilt_weighting,
        "Reconstruction Filter": args.reconstruction_filter,
        "Extend Filter Grid": args.pad_filter,
    }
    if args.ctf_file is not None or args.defocus is not None:
        filter_args["CTF File"] = args.ctf_file
        filter_args["Defocus"] = args.defocus
        filter_args["Phase Shift"] = args.phase_shift
        filter_args["Flip Phase"] = args.no_flip_phase
        filter_args["Acceleration Voltage"] = args.acceleration_voltage
        filter_args["Spherical Aberration"] = args.spherical_aberration
        filter_args["Amplitude Contrast"] = args.amplitude_contrast
        filter_args["Correct Defocus"] = args.correct_defocus_gradient

    filter_args = {k: v for k, v in filter_args.items() if v is not None}
    if len(filter_args):
        print_block(
            name="Filters",
            data=filter_args,
            # BUGFIX: label width was computed from `options`, misaligning this table.
            label_width=max(len(key) for key in filter_args.keys()) + 3,
        )

    analyzer_args = {
        "score_threshold": args.score_threshold,
        "number_of_peaks": args.number_of_peaks,
        "min_distance": max(template.shape) // 3,
        "use_memmap": args.use_memmap,
    }
    analyzer_data = {"Analyzer": callback_class, **analyzer_args}
    print_block(
        name="Analyzer",
        data=analyzer_data,
        # BUGFIX: label width was computed from `options`, misaligning this table.
        label_width=max(len(key) for key in analyzer_data.keys()) + 3,
    )
    print("\n" + "-" * 80)

    outer_jobs = f"{schedule[0]} job{'s' if schedule[0] > 1 else ''}"
    inner_jobs = f"{schedule[1]} core{'s' if schedule[1] > 1 else ''}"
    n_splits = f"{n_splits} split{'s' if n_splits > 1 else ''}"
    print(f"\nDistributing {n_splits} on {outer_jobs} each using {inner_jobs}.")

    start = time()
    print("Running Template Matching. This might take a while ...")
    candidates = scan_subsets(
        matching_data=matching_data,
        job_schedule=schedule,
        matching_score=matching_score,
        matching_setup=matching_setup,
        callback_class=callback_class,
        callback_class_args=analyzer_args,
        target_splits=splits,
        pad_target_edges=args.pad_edges,
        pad_template_filter=args.pad_filter,
        interpolation_order=args.interpolation_order,
    )

    candidates = list(candidates) if candidates is not None else []
    if issubclass(callback_class, MaxScoreOverRotations):
        # Zero out scores outside the target mask (MCC handles masking itself).
        if target_mask is not None and args.score != "MCC":
            candidates[0] *= target_mask.data
        with warnings.catch_warnings():
            warnings.simplefilter("ignore", category=UserWarning)
            nbytes = be.datatype_bytes(be._float_dtype)
            dtype = np.float32 if nbytes == 4 else np.float16
            rot_dim = matching_data.rotations.shape[1]
            # Rehydrate rotation matrices from their byte-string dict keys.
            candidates[3] = {
                x: np.frombuffer(i, dtype=dtype).reshape(rot_dim, rot_dim)
                for i, x in candidates[3].items()
            }
    candidates.append((target.origin, template.origin, template.sampling_rate, args))
    write_pickle(data=candidates, filename=args.output)

    runtime = time() - start
    print("\n" + "-" * 80)
    print(f"\nRuntime real: {runtime:.3f}s user: {(runtime * args.cores):.3f}s.")
1132
+
1133
+
1134
# Script entry point: only run template matching when executed directly,
# not when this module is imported.
if __name__ == "__main__":
    main()