pytme 0.2.0b0__cp311-cp311-macosx_14_0_arm64.whl → 0.2.2__cp311-cp311-macosx_14_0_arm64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (52) hide show
  1. pytme-0.2.2.data/scripts/match_template.py +1187 -0
  2. {pytme-0.2.0b0.data → pytme-0.2.2.data}/scripts/postprocess.py +170 -71
  3. {pytme-0.2.0b0.data → pytme-0.2.2.data}/scripts/preprocessor_gui.py +179 -86
  4. pytme-0.2.2.dist-info/METADATA +91 -0
  5. pytme-0.2.2.dist-info/RECORD +74 -0
  6. {pytme-0.2.0b0.dist-info → pytme-0.2.2.dist-info}/WHEEL +1 -1
  7. scripts/extract_candidates.py +126 -87
  8. scripts/match_template.py +596 -209
  9. scripts/match_template_filters.py +571 -223
  10. scripts/postprocess.py +170 -71
  11. scripts/preprocessor_gui.py +179 -86
  12. scripts/refine_matches.py +567 -159
  13. tme/__init__.py +0 -1
  14. tme/__version__.py +1 -1
  15. tme/analyzer.py +627 -855
  16. tme/backends/__init__.py +41 -11
  17. tme/backends/_jax_utils.py +185 -0
  18. tme/backends/cupy_backend.py +120 -225
  19. tme/backends/jax_backend.py +282 -0
  20. tme/backends/matching_backend.py +464 -388
  21. tme/backends/mlx_backend.py +45 -68
  22. tme/backends/npfftw_backend.py +256 -514
  23. tme/backends/pytorch_backend.py +41 -154
  24. tme/density.py +312 -421
  25. tme/extensions.cpython-311-darwin.so +0 -0
  26. tme/matching_data.py +366 -303
  27. tme/matching_exhaustive.py +279 -1521
  28. tme/matching_optimization.py +234 -129
  29. tme/matching_scores.py +884 -0
  30. tme/matching_utils.py +281 -387
  31. tme/memory.py +377 -0
  32. tme/orientations.py +226 -66
  33. tme/parser.py +3 -4
  34. tme/preprocessing/__init__.py +2 -0
  35. tme/preprocessing/_utils.py +217 -0
  36. tme/preprocessing/composable_filter.py +31 -0
  37. tme/preprocessing/compose.py +55 -0
  38. tme/preprocessing/frequency_filters.py +388 -0
  39. tme/preprocessing/tilt_series.py +1011 -0
  40. tme/preprocessor.py +574 -530
  41. tme/structure.py +495 -189
  42. tme/types.py +5 -3
  43. pytme-0.2.0b0.data/scripts/match_template.py +0 -800
  44. pytme-0.2.0b0.dist-info/METADATA +0 -73
  45. pytme-0.2.0b0.dist-info/RECORD +0 -66
  46. tme/helpers.py +0 -881
  47. tme/matching_constrained.py +0 -195
  48. {pytme-0.2.0b0.data → pytme-0.2.2.data}/scripts/estimate_ram_usage.py +0 -0
  49. {pytme-0.2.0b0.data → pytme-0.2.2.data}/scripts/preprocess.py +0 -0
  50. {pytme-0.2.0b0.dist-info → pytme-0.2.2.dist-info}/LICENSE +0 -0
  51. {pytme-0.2.0b0.dist-info → pytme-0.2.2.dist-info}/entry_points.txt +0 -0
  52. {pytme-0.2.0b0.dist-info → pytme-0.2.2.dist-info}/top_level.txt +0 -0
scripts/match_template.py CHANGED
@@ -8,19 +8,19 @@
8
8
  import os
9
9
  import argparse
10
10
  import warnings
11
- import importlib.util
12
11
  from sys import exit
13
12
  from time import time
13
+ from typing import Tuple
14
14
  from copy import deepcopy
15
- from os.path import abspath
15
+ from os.path import abspath, exists
16
16
 
17
17
  import numpy as np
18
18
 
19
- from tme import Density, Preprocessor, __version__
19
+ from tme import Density, __version__
20
20
  from tme.matching_utils import (
21
21
  get_rotation_matrices,
22
+ get_rotations_around_vector,
22
23
  compute_parallelization_schedule,
23
- euler_from_rotationmatrix,
24
24
  scramble_phases,
25
25
  generate_tempfile_name,
26
26
  write_pickle,
@@ -31,7 +31,8 @@ from tme.analyzer import (
31
31
  MaxScoreOverRotations,
32
32
  PeakCallerMaximumFilter,
33
33
  )
34
- from tme.backends import backend
34
+ from tme.backends import backend as be
35
+ from tme.preprocessing import Compose
35
36
 
36
37
 
37
38
  def get_func_fullname(func) -> str:
@@ -49,7 +50,7 @@ def print_block(name: str, data: dict, label_width=20) -> None:
49
50
 
50
51
  def print_entry() -> None:
51
52
  width = 80
52
- text = f" pyTME v{__version__} "
53
+ text = f" pytme v{__version__} "
53
54
  padding_total = width - len(text) - 2
54
55
  padding_left = padding_total // 2
55
56
  padding_right = padding_total - padding_left
@@ -150,8 +151,200 @@ def crop_data(data: Density, cutoff: float, data_mask: Density = None) -> bool:
150
151
  return True
151
152
 
152
153
 
154
+ def parse_rotation_logic(args, ndim):
155
+ if args.angular_sampling is not None:
156
+ rotations = get_rotation_matrices(
157
+ angular_sampling=args.angular_sampling,
158
+ dim=ndim,
159
+ use_optimized_set=not args.no_use_optimized_set,
160
+ )
161
+ if args.angular_sampling >= 180:
162
+ rotations = np.eye(ndim).reshape(1, ndim, ndim)
163
+ return rotations
164
+
165
+ if args.axis_sampling is None:
166
+ args.axis_sampling = args.cone_sampling
167
+
168
+ rotations = get_rotations_around_vector(
169
+ cone_angle=args.cone_angle,
170
+ cone_sampling=args.cone_sampling,
171
+ axis_angle=args.axis_angle,
172
+ axis_sampling=args.axis_sampling,
173
+ n_symmetry=args.axis_symmetry,
174
+ )
175
+ return rotations
176
+
177
+
178
+ # TODO: Think about whether wedge mask should also be added to target
179
+ # For now leave it at the cost of incorrect upper bound on the scores
180
+ def setup_filter(args, template: Density, target: Density) -> Tuple[Compose, Compose]:
181
+ from tme.preprocessing import LinearWhiteningFilter, BandPassFilter
182
+ from tme.preprocessing.tilt_series import (
183
+ Wedge,
184
+ WedgeReconstructed,
185
+ ReconstructFromTilt,
186
+ )
187
+
188
+ template_filter, target_filter = [], []
189
+ if args.tilt_angles is not None:
190
+ try:
191
+ wedge = Wedge.from_file(args.tilt_angles)
192
+ wedge.weight_type = args.tilt_weighting
193
+ if args.tilt_weighting in ("angle", None) and args.ctf_file is None:
194
+ wedge = WedgeReconstructed(
195
+ angles=wedge.angles, weight_wedge=args.tilt_weighting == "angle"
196
+ )
197
+ except FileNotFoundError:
198
+ tilt_step, create_continuous_wedge = None, True
199
+ tilt_start, tilt_stop = args.tilt_angles.split(",")
200
+ if ":" in tilt_stop:
201
+ create_continuous_wedge = False
202
+ tilt_stop, tilt_step = tilt_stop.split(":")
203
+ tilt_start, tilt_stop = float(tilt_start), float(tilt_stop)
204
+ tilt_angles = (tilt_start, tilt_stop)
205
+ if tilt_step is not None:
206
+ tilt_step = float(tilt_step)
207
+ tilt_angles = np.arange(
208
+ -tilt_start, tilt_stop + tilt_step, tilt_step
209
+ ).tolist()
210
+
211
+ if args.tilt_weighting is not None and tilt_step is None:
212
+ raise ValueError(
213
+ "Tilt weighting is not supported for continuous wedges."
214
+ )
215
+ if args.tilt_weighting not in ("angle", None):
216
+ raise ValueError(
217
+ "Tilt weighting schemes other than 'angle' or 'None' require "
218
+ "a specification of electron doses via --tilt_angles."
219
+ )
220
+
221
+ wedge = Wedge(
222
+ angles=tilt_angles,
223
+ opening_axis=args.wedge_axes[0],
224
+ tilt_axis=args.wedge_axes[1],
225
+ shape=None,
226
+ weight_type=None,
227
+ weights=np.ones_like(tilt_angles),
228
+ )
229
+ if args.tilt_weighting in ("angle", None) and args.ctf_file is None:
230
+ wedge = WedgeReconstructed(
231
+ angles=tilt_angles,
232
+ weight_wedge=args.tilt_weighting == "angle",
233
+ create_continuous_wedge=create_continuous_wedge,
234
+ )
235
+
236
+ wedge.opening_axis = args.wedge_axes[0]
237
+ wedge.tilt_axis = args.wedge_axes[1]
238
+ wedge.sampling_rate = template.sampling_rate
239
+ template_filter.append(wedge)
240
+ if not isinstance(wedge, WedgeReconstructed):
241
+ template_filter.append(
242
+ ReconstructFromTilt(
243
+ reconstruction_filter=args.reconstruction_filter,
244
+ interpolation_order=args.reconstruction_interpolation_order,
245
+ )
246
+ )
247
+
248
+ if args.ctf_file is not None or args.defocus is not None:
249
+ from tme.preprocessing.tilt_series import CTF
250
+
251
+ needs_reconstruction = True
252
+ if args.ctf_file is not None:
253
+ ctf = CTF.from_file(args.ctf_file)
254
+ n_tilts_ctfs, n_tils_angles = len(ctf.defocus_x), len(wedge.angles)
255
+ if n_tilts_ctfs != n_tils_angles:
256
+ raise ValueError(
257
+ f"CTF file contains {n_tilts_ctfs} micrographs, but match_template "
258
+ f"recieved {n_tils_angles} tilt angles. Expected one angle "
259
+ "per micrograph."
260
+ )
261
+ ctf.angles = wedge.angles
262
+ ctf.opening_axis, ctf.tilt_axis = args.wedge_axes
263
+ else:
264
+ needs_reconstruction = False
265
+ ctf = CTF(
266
+ defocus_x=[args.defocus],
267
+ phase_shift=[args.phase_shift],
268
+ defocus_y=None,
269
+ angles=[0],
270
+ shape=None,
271
+ return_real_fourier=True,
272
+ )
273
+ ctf.sampling_rate = template.sampling_rate
274
+ ctf.flip_phase = args.no_flip_phase
275
+ ctf.amplitude_contrast = args.amplitude_contrast
276
+ ctf.spherical_aberration = args.spherical_aberration
277
+ ctf.acceleration_voltage = args.acceleration_voltage * 1e3
278
+ ctf.correct_defocus_gradient = args.correct_defocus_gradient
279
+
280
+ if not needs_reconstruction:
281
+ template_filter.append(ctf)
282
+ elif isinstance(template_filter[-1], ReconstructFromTilt):
283
+ template_filter.insert(-1, ctf)
284
+ else:
285
+ template_filter.insert(0, ctf)
286
+ template_filter.insert(
287
+ 1,
288
+ ReconstructFromTilt(
289
+ reconstruction_filter=args.reconstruction_filter,
290
+ interpolation_order=args.reconstruction_interpolation_order,
291
+ ),
292
+ )
293
+
294
+ if args.lowpass or args.highpass is not None:
295
+ lowpass, highpass = args.lowpass, args.highpass
296
+ if args.pass_format == "voxel":
297
+ if lowpass is not None:
298
+ lowpass = np.max(np.multiply(lowpass, template.sampling_rate))
299
+ if highpass is not None:
300
+ highpass = np.max(np.multiply(highpass, template.sampling_rate))
301
+ elif args.pass_format == "frequency":
302
+ if lowpass is not None:
303
+ lowpass = np.max(np.divide(template.sampling_rate, lowpass))
304
+ if highpass is not None:
305
+ highpass = np.max(np.divide(template.sampling_rate, highpass))
306
+
307
+ try:
308
+ if args.lowpass >= args.highpass:
309
+ warnings.warn("--lowpass should be smaller than --highpass.")
310
+ except Exception:
311
+ pass
312
+
313
+ bandpass = BandPassFilter(
314
+ use_gaussian=args.no_pass_smooth,
315
+ lowpass=lowpass,
316
+ highpass=highpass,
317
+ sampling_rate=template.sampling_rate,
318
+ )
319
+ template_filter.append(bandpass)
320
+
321
+ if not args.no_filter_target:
322
+ target_filter.append(bandpass)
323
+
324
+ if args.whiten_spectrum:
325
+ whitening_filter = LinearWhiteningFilter()
326
+ template_filter.append(whitening_filter)
327
+ target_filter.append(whitening_filter)
328
+
329
+ needs_reconstruction = any(
330
+ [isinstance(t, ReconstructFromTilt) for t in template_filter]
331
+ )
332
+ if needs_reconstruction and args.reconstruction_filter is None:
333
+ warnings.warn(
334
+ "Consider using a --reconstruction_filter such as 'ramp' to avoid artifacts."
335
+ )
336
+
337
+ template_filter = Compose(template_filter) if len(template_filter) else None
338
+ target_filter = Compose(target_filter) if len(target_filter) else None
339
+
340
+ return template_filter, target_filter
341
+
342
+
153
343
  def parse_args():
154
- parser = argparse.ArgumentParser(description="Perform template matching.")
344
+ parser = argparse.ArgumentParser(
345
+ description="Perform template matching.",
346
+ formatter_class=argparse.ArgumentDefaultsHelpFormatter,
347
+ )
155
348
 
156
349
  io_group = parser.add_argument_group("Input / Output")
157
350
  io_group.add_argument(
@@ -221,22 +414,65 @@ def parse_args():
221
414
  choices=list(MATCHING_EXHAUSTIVE_REGISTER.keys()),
222
415
  help="Template matching scoring function.",
223
416
  )
224
- scoring_group.add_argument(
225
- "-p",
226
- dest="peak_calling",
227
- action="store_true",
228
- default=False,
229
- help="Perform peak calling instead of score aggregation.",
230
- )
231
- scoring_group.add_argument(
417
+
418
+ angular_group = parser.add_argument_group("Angular Sampling")
419
+ angular_exclusive = angular_group.add_mutually_exclusive_group(required=True)
420
+
421
+ angular_exclusive.add_argument(
232
422
  "-a",
233
423
  dest="angular_sampling",
234
424
  type=check_positive,
235
- default=40.0,
236
- help="Angular sampling rate for template matching. "
425
+ default=None,
426
+ help="Angular sampling rate using optimized rotational sets."
237
427
  "A lower number yields more rotations. Values >= 180 sample only the identity.",
238
428
  )
239
-
429
+ angular_exclusive.add_argument(
430
+ "--cone_angle",
431
+ dest="cone_angle",
432
+ type=check_positive,
433
+ default=None,
434
+ help="Half-angle of the cone to be sampled in degrees. Allows to sample a "
435
+ "narrow interval around a known orientation, e.g. for surface oversampling.",
436
+ )
437
+ angular_group.add_argument(
438
+ "--cone_sampling",
439
+ dest="cone_sampling",
440
+ type=check_positive,
441
+ default=None,
442
+ help="Sampling rate of the cone in degrees.",
443
+ )
444
+ angular_group.add_argument(
445
+ "--axis_angle",
446
+ dest="axis_angle",
447
+ type=check_positive,
448
+ default=360.0,
449
+ required=False,
450
+ help="Sampling angle along the z-axis of the cone.",
451
+ )
452
+ angular_group.add_argument(
453
+ "--axis_sampling",
454
+ dest="axis_sampling",
455
+ type=check_positive,
456
+ default=None,
457
+ required=False,
458
+ help="Sampling rate along the z-axis of the cone. Defaults to --cone_sampling.",
459
+ )
460
+ angular_group.add_argument(
461
+ "--axis_symmetry",
462
+ dest="axis_symmetry",
463
+ type=check_positive,
464
+ default=1,
465
+ required=False,
466
+ help="N-fold symmetry around z-axis of the cone.",
467
+ )
468
+ angular_group.add_argument(
469
+ "--no_use_optimized_set",
470
+ dest="no_use_optimized_set",
471
+ action="store_true",
472
+ default=False,
473
+ required=False,
474
+ help="Whether to use random uniform instead of optimized rotation sets.",
475
+ )
240
476
 
241
477
  computation_group = parser.add_argument_group("Computation")
242
478
  computation_group.add_argument(
@@ -279,23 +515,7 @@ def parse_args():
279
515
  required=False,
280
516
  type=float,
281
517
  default=0.85,
282
- help="Fraction of available memory that can be used. Defaults to 0.85 and is "
283
- "ignored if --ram is set",
284
- )
285
- computation_group.add_argument(
286
- "--use_mixed_precision",
287
- dest="use_mixed_precision",
288
- action="store_true",
289
- default=False,
290
- help="Use float16 for real values operations where possible.",
291
- )
292
- computation_group.add_argument(
293
- "--use_memmap",
294
- dest="use_memmap",
295
- action="store_true",
296
- default=False,
297
- help="Use memmaps to offload large data objects to disk. "
298
- "Particularly useful for large inputs in combination with --use_gpu.",
518
+ help="Fraction of available memory to be used. Ignored if --ram is set.",
299
519
  )
300
520
  computation_group.add_argument(
301
521
  "--temp_directory",
@@ -303,64 +523,176 @@ def parse_args():
303
523
  default=None,
304
524
  help="Directory for temporary objects. Faster I/O improves runtime.",
305
525
  )
306
-
526
+ computation_group.add_argument(
527
+ "--backend",
528
+ dest="backend",
529
+ default=None,
530
+ choices=be.available_backends(),
531
+ help="[Expert] Overwrite default computation backend.",
532
+ )
307
533
  filter_group = parser.add_argument_group("Filters")
308
534
  filter_group.add_argument(
309
- "--gaussian_sigma",
310
- dest="gaussian_sigma",
535
+ "--lowpass",
536
+ dest="lowpass",
311
537
  type=float,
312
538
  required=False,
313
- help="Sigma parameter for Gaussian filtering the template.",
539
+ help="Resolution to lowpass filter template and target to in the same unit "
540
+ "as the sampling rate of template and target (typically Ångstrom).",
541
+ )
542
+ filter_group.add_argument(
543
+ "--highpass",
544
+ dest="highpass",
545
+ type=float,
546
+ required=False,
547
+ help="Resolution to highpass filter template and target to in the same unit "
548
+ "as the sampling rate of template and target (typically Ångstrom).",
549
+ )
550
+ filter_group.add_argument(
551
+ "--no_pass_smooth",
552
+ dest="no_pass_smooth",
553
+ action="store_false",
554
+ default=True,
555
+ help="Whether a hard edge filter should be used for --lowpass and --highpass.",
314
556
  )
315
557
  filter_group.add_argument(
316
- "--bandpass_band",
317
- dest="bandpass_band",
558
+ "--pass_format",
559
+ dest="pass_format",
318
560
  type=str,
319
561
  required=False,
320
- help="Comma separated start and stop frequency for bandpass filtering the"
321
- " template, e.g. 0.1, 0.5",
562
+ default="sampling_rate",
563
+ choices=["sampling_rate", "voxel", "frequency"],
564
+ help="How values passed to --lowpass and --highpass should be interpreted. ",
322
565
  )
323
566
  filter_group.add_argument(
324
- "--bandpass_smooth",
325
- dest="bandpass_smooth",
326
- type=float,
567
+ "--whiten_spectrum",
568
+ dest="whiten_spectrum",
569
+ action="store_true",
570
+ default=None,
571
+ help="Apply spectral whitening to template and target based on target spectrum.",
572
+ )
573
+ filter_group.add_argument(
574
+ "--wedge_axes",
575
+ dest="wedge_axes",
576
+ type=str,
327
577
  required=False,
328
578
  default=None,
329
- help="Sigma smooth parameter for the bandpass filter.",
579
+ help="Indices of wedge opening and tilt axis, e.g. 0,2 for a wedge that is open "
580
+ "in z-direction and tilted over the x axis.",
330
581
  )
331
582
  filter_group.add_argument(
332
- "--tilt_range",
333
- dest="tilt_range",
583
+ "--tilt_angles",
584
+ dest="tilt_angles",
334
585
  type=str,
335
586
  required=False,
336
- help="Comma separated start and stop stage tilt angle, e.g. '50,45'. Used"
337
- " to create a wedge mask to be applied to the template.",
587
+ default=None,
588
+ help="Path to a tab-separated file containing the column angles and optionally "
589
+ " weights, or comma separated start and stop stage tilt angle, e.g. 50,45, which "
590
+ " yields a continuous wedge mask. Alternatively, a tilt step size can be "
591
+ "specified like 50,45:5.0 to sample 5.0 degree tilt angle steps.",
338
592
  )
339
593
  filter_group.add_argument(
340
- "--tilt_step",
341
- dest="tilt_step",
342
- type=float,
594
+ "--tilt_weighting",
595
+ dest="tilt_weighting",
596
+ type=str,
343
597
  required=False,
598
+ choices=["angle", "relion", "grigorieff"],
344
599
  default=None,
345
- help="Step size between tilts. e.g. '5'. When set the wedge mask"
346
- " reflects the individual tilts, otherwise a continuous mask is used.",
600
+ help="Weighting scheme used to reweight individual tilts. Available options: "
601
+ "angle (cosine based weighting), "
602
+ "relion (relion formalism for wedge weighting) requires,"
603
+ "grigorieff (exposure filter as defined in Grant and Grigorieff 2015)."
604
+ "relion and grigorieff require electron doses in --tilt_angles weights column.",
347
605
  )
348
606
  filter_group.add_argument(
349
- "--wedge_axes",
350
- dest="wedge_axes",
607
+ "--reconstruction_filter",
608
+ dest="reconstruction_filter",
351
609
  type=str,
352
610
  required=False,
353
- default="0,2",
354
- help="Axis index of wedge opening and tilt axis, e.g. 0,2 for a wedge that is open in"
355
- " z and tilted over x.",
611
+ choices=["ram-lak", "ramp", "ramp-cont", "shepp-logan", "cosine", "hamming"],
612
+ default=None,
613
+ help="Filter applied when reconstructing (N+1)-D from N-D filters.",
356
614
  )
357
615
  filter_group.add_argument(
358
- "--wedge_smooth",
359
- dest="wedge_smooth",
616
+ "--reconstruction_interpolation_order",
617
+ dest="reconstruction_interpolation_order",
618
+ type=int,
619
+ default=1,
620
+ required=False,
621
+ help="Analogous to --interpolation_order but for reconstruction.",
622
+ )
623
+ filter_group.add_argument(
624
+ "--no_filter_target",
625
+ dest="no_filter_target",
626
+ action="store_true",
627
+ default=False,
628
+ help="Whether to not apply potential filters to the target.",
629
+ )
630
+
631
+ ctf_group = parser.add_argument_group("Contrast Transfer Function")
632
+ ctf_group.add_argument(
633
+ "--ctf_file",
634
+ dest="ctf_file",
635
+ type=str,
636
+ required=False,
637
+ default=None,
638
+ help="Path to a file with CTF parameters from CTFFIND4. Each line will be "
639
+ "interpreted as tilt obtained at the angle specified in --tilt_angles. ",
640
+ )
641
+ ctf_group.add_argument(
642
+ "--defocus",
643
+ dest="defocus",
360
644
  type=float,
361
645
  required=False,
362
646
  default=None,
363
- help="Sigma smooth parameter for the wedge mask.",
647
+ help="Defocus in units of sampling rate (typically Ångstrom). "
648
+ "Superseded by --ctf_file.",
649
+ )
650
+ ctf_group.add_argument(
651
+ "--phase_shift",
652
+ dest="phase_shift",
653
+ type=float,
654
+ required=False,
655
+ default=0,
656
+ help="Phase shift in degrees. Superseded by --ctf_file.",
657
+ )
658
+ ctf_group.add_argument(
659
+ "--acceleration_voltage",
660
+ dest="acceleration_voltage",
661
+ type=float,
662
+ required=False,
663
+ default=300,
664
+ help="Acceleration voltage in kV.",
665
+ )
666
+ ctf_group.add_argument(
667
+ "--spherical_aberration",
668
+ dest="spherical_aberration",
669
+ type=float,
670
+ required=False,
671
+ default=2.7e7,
672
+ help="Spherical aberration in units of sampling rate (typically Ångstrom).",
673
+ )
674
+ ctf_group.add_argument(
675
+ "--amplitude_contrast",
676
+ dest="amplitude_contrast",
677
+ type=float,
678
+ required=False,
679
+ default=0.07,
680
+ help="Amplitude contrast.",
681
+ )
682
+ ctf_group.add_argument(
683
+ "--no_flip_phase",
684
+ dest="no_flip_phase",
685
+ action="store_false",
686
+ required=False,
687
+ help="Do not perform phase-flipping CTF correction.",
688
+ )
689
+ ctf_group.add_argument(
690
+ "--correct_defocus_gradient",
691
+ dest="correct_defocus_gradient",
692
+ action="store_true",
693
+ required=False,
694
+ help="[Experimental] Whether to compute a more accurate 3D CTF incorporating "
695
+ "defocus gradients.",
364
696
  )
365
697
 
366
698
  performance_group = parser.add_argument_group("Performance")
@@ -403,14 +735,37 @@ def parse_args():
403
735
  "for numerical stability. When working with very large targets, e.g. tomograms, "
404
736
  "it is safe to use this flag and benefit from the performance gain.",
405
737
  )
738
+ performance_group.add_argument(
739
+ "--no_filter_padding",
740
+ dest="no_filter_padding",
741
+ action="store_true",
742
+ default=False,
743
+ help="Omits padding of optional template filters. Particularly effective when "
744
+ "the target is much larger than the template. However, for fast osciliating "
745
+ "filters setting this flag can introduce aliasing effects.",
746
+ )
406
747
  performance_group.add_argument(
407
748
  "--interpolation_order",
408
749
  dest="interpolation_order",
409
750
  required=False,
410
751
  type=int,
411
752
  default=3,
412
- help="Spline interpolation used for template rotations. If less than zero "
413
- "no interpolation is performed.",
753
+ help="Spline interpolation used for rotations.",
754
+ )
755
+ performance_group.add_argument(
756
+ "--use_mixed_precision",
757
+ dest="use_mixed_precision",
758
+ action="store_true",
759
+ default=False,
760
+ help="Use float16 for real values operations where possible.",
761
+ )
762
+ performance_group.add_argument(
763
+ "--use_memmap",
764
+ dest="use_memmap",
765
+ action="store_true",
766
+ default=False,
767
+ help="Use memmaps to offload large data objects to disk. "
768
+ "Particularly useful for large inputs in combination with --use_gpu.",
414
769
  )
415
770
 
416
771
  analyzer_group = parser.add_argument_group("Analyzer")
@@ -422,8 +777,22 @@ def parse_args():
422
777
  default=0,
423
778
  help="Minimum template matching scores to consider for analysis.",
424
779
  )
425
-
780
+ analyzer_group.add_argument(
781
+ "-p",
782
+ dest="peak_calling",
783
+ action="store_true",
784
+ default=False,
785
+ help="Perform peak calling instead of score aggregation.",
786
+ )
787
+ analyzer_group.add_argument(
788
+ "--number_of_peaks",
789
+ dest="number_of_peaks",
790
+ action="store_true",
791
+ default=1000,
792
+ help="Number of peaks to call, 1000 by default.",
793
+ )
426
794
  args = parser.parse_args()
795
+ args.version = __version__
427
796
 
428
797
  if args.interpolation_order < 0:
429
798
  args.interpolation_order = None
@@ -460,6 +829,21 @@ def parse_args():
460
829
  int(x) for x in os.environ["CUDA_VISIBLE_DEVICES"].split(",")
461
830
  ]
462
831
 
832
+ if args.tilt_angles is not None:
833
+ if args.wedge_axes is None:
834
+ raise ValueError("Need to specify --wedge_axes when --tilt_angles is set.")
835
+ if not exists(args.tilt_angles):
836
+ try:
837
+ float(args.tilt_angles.split(",")[0])
838
+ except ValueError:
839
+ raise ValueError(f"{args.tilt_angles} is not a file nor a range.")
840
+
841
+ if args.ctf_file is not None and args.tilt_angles is None:
842
+ raise ValueError("Need to specify --tilt_angles when --ctf_file is set.")
843
+
844
+ if args.wedge_axes is not None:
845
+ args.wedge_axes = tuple(int(i) for i in args.wedge_axes.split(","))
846
+
463
847
  return args
464
848
 
465
849
 
@@ -477,12 +861,13 @@ def main():
477
861
  sampling_rate=target.sampling_rate,
478
862
  )
479
863
 
480
- if not np.allclose(target.sampling_rate, template.sampling_rate):
481
- print(
482
- f"Resampling template to {target.sampling_rate}. "
483
- "Consider providing a template with the same sampling rate as the target."
484
- )
485
- template = template.resample(target.sampling_rate, order=3)
864
+ if target.sampling_rate.size == template.sampling_rate.size:
865
+ if not np.allclose(target.sampling_rate, template.sampling_rate):
866
+ print(
867
+ f"Resampling template to {target.sampling_rate}. "
868
+ "Consider providing a template with the same sampling rate as the target."
869
+ )
870
+ template = template.resample(target.sampling_rate, order=3)
486
871
 
487
872
  template_mask = load_and_validate_mask(
488
873
  mask_target=template, mask_path=args.template_mask
@@ -536,51 +921,6 @@ def main():
536
921
  },
537
922
  )
538
923
 
539
- template_filter = {}
540
- if args.gaussian_sigma is not None:
541
- template.data = Preprocessor().gaussian_filter(
542
- sigma=args.gaussian_sigma, template=template.data
543
- )
544
-
545
- if args.bandpass_band is not None:
546
- bandpass_start, bandpass_stop = [
547
- float(x) for x in args.bandpass_band.split(",")
548
- ]
549
- if args.bandpass_smooth is None:
550
- args.bandpass_smooth = 0
551
-
552
- template_filter["bandpass_mask"] = {
553
- "minimum_frequency": bandpass_start,
554
- "maximum_frequency": bandpass_stop,
555
- "gaussian_sigma": args.bandpass_smooth,
556
- }
557
-
558
- if args.tilt_range is not None:
559
- args.wedge_smooth if args.wedge_smooth is not None else 0
560
- tilt_start, tilt_stop = [float(x) for x in args.tilt_range.split(",")]
561
- opening_axis, tilt_axis = [int(x) for x in args.wedge_axes.split(",")]
562
-
563
- if args.tilt_step is not None:
564
- template_filter["step_wedge_mask"] = {
565
- "start_tilt": tilt_start,
566
- "stop_tilt": tilt_stop,
567
- "tilt_step": args.tilt_step,
568
- "sigma": args.wedge_smooth,
569
- "opening_axis": opening_axis,
570
- "tilt_axis": tilt_axis,
571
- "omit_negative_frequencies": True,
572
- }
573
- else:
574
- template_filter["continuous_wedge_mask"] = {
575
- "start_tilt": tilt_start,
576
- "stop_tilt": tilt_stop,
577
- "tilt_axis": tilt_axis,
578
- "opening_axis": opening_axis,
579
- "infinite_plane": True,
580
- "sigma": args.wedge_smooth,
581
- "omit_negative_frequencies": True,
582
- }
583
-
584
924
  if template_mask is None:
585
925
  template_mask = template.empty
586
926
  if not args.no_centering:
@@ -619,72 +959,97 @@ def main():
619
959
  template.data, noise_proportion=1.0, normalize_power=True
620
960
  )
621
961
 
622
- available_memory = backend.get_available_memory()
962
+ # Determine suitable backend for the selected operation
963
+ available_backends = be.available_backends()
964
+ if args.backend is not None:
965
+ req_backend = args.backend
966
+ if req_backend not in available_backends:
967
+ raise ValueError("Requested backend is not available.")
968
+ available_backends = [req_backend]
969
+
970
+ be_selection = ("numpyfftw", "pytorch", "jax", "mlx")
623
971
  if args.use_gpu:
624
972
  args.cores = len(args.gpu_indices)
625
- has_torch = importlib.util.find_spec("torch") is not None
626
- has_cupy = importlib.util.find_spec("cupy") is not None
973
+ be_selection = ("pytorch", "cupy", "jax")
974
+ if args.use_mixed_precision:
975
+ be_selection = tuple(x for x in be_selection if x in ("cupy", "numpyfftw"))
627
976
 
628
- if not has_torch and not has_cupy:
629
- raise ValueError(
630
- "Found neither CuPy nor PyTorch installation. You need to install"
631
- " either to enable GPU support."
632
- )
633
-
634
- if args.peak_calling:
635
- preferred_backend = "pytorch"
636
- if not has_torch:
637
- preferred_backend = "cupy"
638
- backend.change_backend(backend_name=preferred_backend, device="cuda")
639
- else:
640
- preferred_backend = "cupy"
641
- if not has_cupy:
642
- preferred_backend = "pytorch"
643
- backend.change_backend(backend_name=preferred_backend, device="cuda")
644
- if args.use_mixed_precision and preferred_backend == "pytorch":
977
+ available_backends = [x for x in available_backends if x in be_selection]
978
+ if args.peak_calling:
979
+ if "jax" in available_backends:
980
+ available_backends.remove("jax")
981
+ if args.use_gpu and "pytorch" in available_backends:
982
+ available_backends = ("pytorch",)
983
+ if args.interpolation_order == 3:
645
984
  raise NotImplementedError(
646
- "pytorch backend does not yet support mixed precision."
647
- " Consider installing CuPy to enable this feature."
648
- )
649
- elif args.use_mixed_precision:
650
- backend.change_backend(
651
- backend_name="cupy",
652
- default_dtype=backend._array_backend.float16,
653
- complex_dtype=backend._array_backend.complex64,
654
- default_dtype_int=backend._array_backend.int16,
985
+ "Pytorch does not support --interpolation_order 3, 1 is supported."
655
986
  )
656
- available_memory = backend.get_available_memory() * args.cores
657
- if preferred_backend == "pytorch" and args.interpolation_order == 3:
658
- args.interpolation_order = 1
987
+ # dim_match = len(template.shape) == len(target.shape) <= 3
988
+ # if dim_match and args.use_gpu and "jax" in available_backends:
989
+ # args.interpolation_order = 1
990
+ # available_backends = ["jax"]
659
991
 
992
+ backend_preference = ("numpyfftw", "pytorch", "jax", "mlx")
993
+ if args.use_gpu:
994
+ backend_preference = ("cupy", "pytorch", "jax")
995
+ for pref in backend_preference:
996
+ if pref not in available_backends:
997
+ continue
998
+ be.change_backend(pref)
999
+ if pref == "pytorch":
1000
+ be.change_backend(pref, device="cuda" if args.use_gpu else "cpu")
1001
+
1002
+ if args.use_mixed_precision:
1003
+ be.change_backend(
1004
+ backend_name=pref,
1005
+ default_dtype=be._array_backend.float16,
1006
+ complex_dtype=be._array_backend.complex64,
1007
+ default_dtype_int=be._array_backend.int16,
1008
+ )
1009
+ break
1010
+
1011
+ available_memory = be.get_available_memory() * be.device_count()
660
1012
  if args.memory is None:
661
1013
  args.memory = int(args.memory_scaling * available_memory)
662
1014
 
663
- target_padding = np.zeros_like(template.shape)
664
- if args.pad_target_edges:
665
- target_padding = template.shape
666
-
667
- template_box = template.shape
668
- if not args.pad_fourier:
669
- template_box = np.ones(len(template_box), dtype=int)
670
-
671
1015
  callback_class = MaxScoreOverRotations
672
1016
  if args.peak_calling:
673
1017
  callback_class = PeakCallerMaximumFilter
674
1018
 
1019
+ matching_data = MatchingData(
1020
+ target=target,
1021
+ template=template.data,
1022
+ target_mask=target_mask,
1023
+ template_mask=template_mask,
1024
+ invert_target=args.invert_target_contrast,
1025
+ rotations=parse_rotation_logic(args=args, ndim=template.data.ndim),
1026
+ )
1027
+
1028
+ matching_data.template_filter, matching_data.target_filter = setup_filter(
1029
+ args, template, target
1030
+ )
1031
+
1032
+ template_box = matching_data._output_template_shape
1033
+ if not args.pad_fourier:
1034
+ template_box = tuple(0 for _ in range(len(template_box)))
1035
+
1036
+ target_padding = tuple(0 for _ in range(len(template_box)))
1037
+ if args.pad_target_edges:
1038
+ target_padding = matching_data._output_template_shape
1039
+
675
1040
  splits, schedule = compute_parallelization_schedule(
676
1041
  shape1=target.shape,
677
- shape2=template_box,
678
- shape1_padding=target_padding,
1042
+ shape2=tuple(int(x) for x in template_box),
1043
+ shape1_padding=tuple(int(x) for x in target_padding),
679
1044
  max_cores=args.cores,
680
1045
  max_ram=args.memory,
681
1046
  split_only_outer=args.use_gpu,
682
1047
  matching_method=args.score,
683
1048
  analyzer_method=callback_class.__name__,
684
- backend=backend._backend_name,
685
- float_nbytes=backend.datatype_bytes(backend._default_dtype),
686
- complex_nbytes=backend.datatype_bytes(backend._complex_dtype),
687
- integer_nbytes=backend.datatype_bytes(backend._default_dtype_int),
1049
+ backend=be._backend_name,
1050
+ float_nbytes=be.datatype_bytes(be._float_dtype),
1051
+ complex_nbytes=be.datatype_bytes(be._complex_dtype),
1052
+ integer_nbytes=be.datatype_bytes(be._int_dtype),
688
1053
  )
689
1054
 
690
1055
  if splits is None:
@@ -694,63 +1059,85 @@ def main():
694
1059
  )
695
1060
  exit(-1)
696
1061
 
697
- analyzer_args = {
698
- "score_threshold": args.score_threshold,
699
- "number_of_peaks": 1000,
700
- "convolution_mode": "valid",
701
- "use_memmap": args.use_memmap,
702
- }
703
-
704
1062
  matching_setup, matching_score = MATCHING_EXHAUSTIVE_REGISTER[args.score]
705
- matching_data = MatchingData(target=target, template=template.data)
706
- matching_data.rotations = get_rotation_matrices(
707
- angular_sampling=args.angular_sampling, dim=target.data.ndim
708
- )
709
- if args.angular_sampling >= 180:
710
- ndim = target.data.ndim
711
- matching_data.rotations = np.eye(ndim).reshape(1, ndim, ndim)
712
-
713
- matching_data.template_filter = template_filter
714
- matching_data._invert_target = args.invert_target_contrast
715
- if target_mask is not None:
716
- matching_data.target_mask = target_mask
717
- if template_mask is not None:
718
- matching_data.template_mask = template_mask.data
719
-
720
1063
  n_splits = np.prod(list(splits.values()))
721
1064
  target_split = ", ".join(
722
1065
  [":".join([str(x) for x in axis]) for axis in splits.items()]
723
1066
  )
724
1067
  gpus_used = 0 if args.gpu_indices is None else len(args.gpu_indices)
725
1068
  options = {
726
- "CPU Cores": args.cores,
727
- "Run on GPU": f"{args.use_gpu} [N={gpus_used}]",
728
- "Use Mixed Precision": args.use_mixed_precision,
729
- "Assigned Memory [MB]": f"{args.memory // 1e6} [out of {available_memory//1e6}]",
730
- "Temporary Directory": args.temp_directory,
1069
+ "Angular Sampling": f"{args.angular_sampling}"
1070
+ f" [{matching_data.rotations.shape[0]} rotations]",
1071
+ "Center Template": not args.no_centering,
1072
+ "Scramble Template": args.scramble_phases,
1073
+ "Invert Contrast": args.invert_target_contrast,
731
1074
  "Extend Fourier Grid": not args.no_fourier_padding,
732
1075
  "Extend Target Edges": not args.no_edge_padding,
733
1076
  "Interpolation Order": args.interpolation_order,
734
- "Score": f"{args.score}",
735
1077
  "Setup Function": f"{get_func_fullname(matching_setup)}",
736
1078
  "Scoring Function": f"{get_func_fullname(matching_score)}",
737
- "Angular Sampling": f"{args.angular_sampling}"
738
- f" [{matching_data.rotations.shape[0]} rotations]",
739
- "Scramble Template": args.scramble_phases,
740
- "Target Splits": f"{target_split} [N={n_splits}]",
741
1079
  }
742
1080
 
743
1081
  print_block(
744
- name="Template Matching Options",
1082
+ name="Template Matching",
745
1083
  data=options,
746
- label_width=max(len(key) for key in options.keys()) + 2,
1084
+ label_width=max(len(key) for key in options.keys()) + 3,
747
1085
  )
748
1086
 
749
- options = {"Analyzer": callback_class, **analyzer_args}
1087
+ compute_options = {
1088
+ "Backend": be._BACKEND_REGISTRY[be._backend_name],
1089
+ "Compute Devices": f"CPU [{args.cores}], GPU [{gpus_used}]",
1090
+ "Use Mixed Precision": args.use_mixed_precision,
1091
+ "Assigned Memory [MB]": f"{args.memory // 1e6} [out of {available_memory//1e6}]",
1092
+ "Temporary Directory": args.temp_directory,
1093
+ "Target Splits": f"{target_split} [N={n_splits}]",
1094
+ }
750
1095
  print_block(
751
- name="Score Analysis Options",
752
- data=options,
753
- label_width=max(len(key) for key in options.keys()) + 2,
1096
+ name="Computation",
1097
+ data=compute_options,
1098
+ label_width=max(len(key) for key in options.keys()) + 3,
1099
+ )
1100
+
1101
+ filter_args = {
1102
+ "Lowpass": args.lowpass,
1103
+ "Highpass": args.highpass,
1104
+ "Smooth Pass": args.no_pass_smooth,
1105
+ "Pass Format": args.pass_format,
1106
+ "Spectral Whitening": args.whiten_spectrum,
1107
+ "Wedge Axes": args.wedge_axes,
1108
+ "Tilt Angles": args.tilt_angles,
1109
+ "Tilt Weighting": args.tilt_weighting,
1110
+ "Reconstruction Filter": args.reconstruction_filter,
1111
+ }
1112
+ if args.ctf_file is not None or args.defocus is not None:
1113
+ filter_args["CTF File"] = args.ctf_file
1114
+ filter_args["Defocus"] = args.defocus
1115
+ filter_args["Phase Shift"] = args.phase_shift
1116
+ filter_args["Flip Phase"] = args.no_flip_phase
1117
+ filter_args["Acceleration Voltage"] = args.acceleration_voltage
1118
+ filter_args["Spherical Aberration"] = args.spherical_aberration
1119
+ filter_args["Amplitude Contrast"] = args.amplitude_contrast
1120
+ filter_args["Correct Defocus"] = args.correct_defocus_gradient
1121
+
1122
+ filter_args = {k: v for k, v in filter_args.items() if v is not None}
1123
+ if len(filter_args):
1124
+ print_block(
1125
+ name="Filters",
1126
+ data=filter_args,
1127
+ label_width=max(len(key) for key in options.keys()) + 3,
1128
+ )
1129
+
1130
+ analyzer_args = {
1131
+ "score_threshold": args.score_threshold,
1132
+ "number_of_peaks": args.number_of_peaks,
1133
+ "min_distance": max(template.shape) // 2,
1134
+ "min_boundary_distance": max(template.shape) // 2,
1135
+ "use_memmap": args.use_memmap,
1136
+ }
1137
+ print_block(
1138
+ name="Analyzer",
1139
+ data={"Analyzer": callback_class, **analyzer_args},
1140
+ label_width=max(len(key) for key in options.keys()) + 3,
754
1141
  )
755
1142
  print("\n" + "-" * 80)
756
1143
 
@@ -771,6 +1158,7 @@ def main():
771
1158
  target_splits=splits,
772
1159
  pad_target_edges=args.pad_target_edges,
773
1160
  pad_fourier=args.pad_fourier,
1161
+ pad_template_filter=not args.no_filter_padding,
774
1162
  interpolation_order=args.interpolation_order,
775
1163
  )
776
1164
 
@@ -780,19 +1168,18 @@ def main():
780
1168
  candidates[0] *= target_mask.data
781
1169
  with warnings.catch_warnings():
782
1170
  warnings.simplefilter("ignore", category=UserWarning)
1171
+ nbytes = be.datatype_bytes(be._float_dtype)
1172
+ dtype = np.float32 if nbytes == 4 else np.float16
1173
+ rot_dim = matching_data.rotations.shape[1]
783
1174
  candidates[3] = {
784
- x: euler_from_rotationmatrix(
785
- np.frombuffer(i, dtype=matching_data.rotations.dtype).reshape(
786
- candidates[0].ndim, candidates[0].ndim
787
- )
788
- )
1175
+ x: np.frombuffer(i, dtype=dtype).reshape(rot_dim, rot_dim)
789
1176
  for i, x in candidates[3].items()
790
1177
  }
791
-
792
- candidates.append((target.origin, template.origin, target.sampling_rate, args))
1178
+ candidates.append((target.origin, template.origin, template.sampling_rate, args))
793
1179
  write_pickle(data=candidates, filename=args.output)
794
1180
 
795
1181
  runtime = time() - start
1182
+ print("\n" + "-" * 80)
796
1183
  print(f"\nRuntime real: {runtime:.3f}s user: {(runtime * args.cores):.3f}s.")
797
1184
 
798
1185