pytme 0.2.1__cp311-cp311-macosx_14_0_arm64.whl → 0.2.3__cp311-cp311-macosx_14_0_arm64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (52) hide show
  1. {pytme-0.2.1.data → pytme-0.2.3.data}/scripts/match_template.py +219 -216
  2. {pytme-0.2.1.data → pytme-0.2.3.data}/scripts/postprocess.py +86 -54
  3. pytme-0.2.3.data/scripts/preprocess.py +132 -0
  4. {pytme-0.2.1.data → pytme-0.2.3.data}/scripts/preprocessor_gui.py +181 -94
  5. pytme-0.2.3.dist-info/METADATA +92 -0
  6. pytme-0.2.3.dist-info/RECORD +75 -0
  7. {pytme-0.2.1.dist-info → pytme-0.2.3.dist-info}/WHEEL +1 -1
  8. pytme-0.2.1.data/scripts/preprocess.py → scripts/eval.py +1 -1
  9. scripts/extract_candidates.py +20 -13
  10. scripts/match_template.py +219 -216
  11. scripts/match_template_filters.py +154 -95
  12. scripts/postprocess.py +86 -54
  13. scripts/preprocess.py +95 -56
  14. scripts/preprocessor_gui.py +181 -94
  15. scripts/refine_matches.py +265 -61
  16. tme/__init__.py +0 -1
  17. tme/__version__.py +1 -1
  18. tme/analyzer.py +458 -813
  19. tme/backends/__init__.py +40 -11
  20. tme/backends/_jax_utils.py +187 -0
  21. tme/backends/cupy_backend.py +109 -226
  22. tme/backends/jax_backend.py +230 -152
  23. tme/backends/matching_backend.py +445 -384
  24. tme/backends/mlx_backend.py +32 -59
  25. tme/backends/npfftw_backend.py +240 -507
  26. tme/backends/pytorch_backend.py +30 -151
  27. tme/density.py +248 -371
  28. tme/extensions.cpython-311-darwin.so +0 -0
  29. tme/matching_data.py +328 -284
  30. tme/matching_exhaustive.py +195 -1499
  31. tme/matching_optimization.py +143 -106
  32. tme/matching_scores.py +887 -0
  33. tme/matching_utils.py +287 -388
  34. tme/memory.py +377 -0
  35. tme/orientations.py +78 -21
  36. tme/parser.py +3 -4
  37. tme/preprocessing/_utils.py +61 -32
  38. tme/preprocessing/composable_filter.py +7 -4
  39. tme/preprocessing/compose.py +7 -3
  40. tme/preprocessing/frequency_filters.py +49 -39
  41. tme/preprocessing/tilt_series.py +44 -72
  42. tme/preprocessor.py +560 -526
  43. tme/structure.py +491 -188
  44. tme/types.py +5 -3
  45. pytme-0.2.1.dist-info/METADATA +0 -73
  46. pytme-0.2.1.dist-info/RECORD +0 -73
  47. tme/helpers.py +0 -881
  48. tme/matching_constrained.py +0 -195
  49. {pytme-0.2.1.data → pytme-0.2.3.data}/scripts/estimate_ram_usage.py +0 -0
  50. {pytme-0.2.1.dist-info → pytme-0.2.3.dist-info}/LICENSE +0 -0
  51. {pytme-0.2.1.dist-info → pytme-0.2.3.dist-info}/entry_points.txt +0 -0
  52. {pytme-0.2.1.dist-info → pytme-0.2.3.dist-info}/top_level.txt +0 -0
@@ -1,5 +1,5 @@
1
1
  #!python
2
- """ CLI interface for basic pyTME template matching functions.
2
+ """ CLI for basic pyTME template matching functions.
3
3
 
4
4
  Copyright (c) 2023 European Molecular Biology Laboratory
5
5
 
@@ -8,7 +8,6 @@
8
8
  import os
9
9
  import argparse
10
10
  import warnings
11
- import importlib.util
12
11
  from sys import exit
13
12
  from time import time
14
13
  from typing import Tuple
@@ -22,9 +21,7 @@ from tme.matching_utils import (
22
21
  get_rotation_matrices,
23
22
  get_rotations_around_vector,
24
23
  compute_parallelization_schedule,
25
- euler_from_rotationmatrix,
26
24
  scramble_phases,
27
- generate_tempfile_name,
28
25
  write_pickle,
29
26
  )
30
27
  from tme.matching_exhaustive import scan_subsets, MATCHING_EXHAUSTIVE_REGISTER
@@ -33,7 +30,7 @@ from tme.analyzer import (
33
30
  MaxScoreOverRotations,
34
31
  PeakCallerMaximumFilter,
35
32
  )
36
- from tme.backends import backend
33
+ from tme.backends import backend as be
37
34
  from tme.preprocessing import Compose
38
35
 
39
36
 
@@ -52,7 +49,7 @@ def print_block(name: str, data: dict, label_width=20) -> None:
52
49
 
53
50
  def print_entry() -> None:
54
51
  width = 80
55
- text = f" pyTME v{__version__} "
52
+ text = f" pytme v{__version__} "
56
53
  padding_total = width - len(text) - 2
57
54
  padding_left = padding_total // 2
58
55
  padding_right = padding_total - padding_left
@@ -101,7 +98,9 @@ def load_and_validate_mask(mask_target: "Density", mask_path: str, **kwargs):
101
98
  f"Expected shape of {mask_path} was {mask_target.shape},"
102
99
  f" got f{mask.shape}"
103
100
  )
104
- if not np.allclose(mask.sampling_rate, mask_target.sampling_rate):
101
+ if not np.allclose(
102
+ np.round(mask.sampling_rate, 2), np.round(mask_target.sampling_rate, 2)
103
+ ):
105
104
  raise ValueError(
106
105
  f"Expected sampling_rate of {mask_path} was {mask_target.sampling_rate}"
107
106
  f", got f{mask.sampling_rate}"
@@ -109,50 +108,6 @@ def load_and_validate_mask(mask_target: "Density", mask_path: str, **kwargs):
109
108
  return mask
110
109
 
111
110
 
112
- def crop_data(data: Density, cutoff: float, data_mask: Density = None) -> bool:
113
- """
114
- Crop the provided data and mask to a smaller box based on a cutoff value.
115
-
116
- Parameters
117
- ----------
118
- data : Density
119
- The data that should be cropped.
120
- cutoff : float
121
- The threshold value to determine which parts of the data should be kept.
122
- data_mask : Density, optional
123
- A mask for the data that should be cropped.
124
-
125
- Returns
126
- -------
127
- bool
128
- Returns True if the data was adjusted (cropped), otherwise returns False.
129
-
130
- Notes
131
- -----
132
- Cropping is performed in place.
133
- """
134
- if cutoff is None:
135
- return False
136
-
137
- box = data.trim_box(cutoff=cutoff)
138
- box_mask = box
139
- if data_mask is not None:
140
- box_mask = data_mask.trim_box(cutoff=cutoff)
141
- box = tuple(
142
- slice(min(arr.start, mask.start), max(arr.stop, mask.stop))
143
- for arr, mask in zip(box, box_mask)
144
- )
145
- if box == tuple(slice(0, x) for x in data.shape):
146
- return False
147
-
148
- data.adjust_box(box)
149
-
150
- if data_mask:
151
- data_mask.adjust_box(box)
152
-
153
- return True
154
-
155
-
156
111
  def parse_rotation_logic(args, ndim):
157
112
  if args.angular_sampling is not None:
158
113
  rotations = get_rotation_matrices(
@@ -177,8 +132,52 @@ def parse_rotation_logic(args, ndim):
177
132
  return rotations
178
133
 
179
134
 
180
- # TODO: Think about whether wedge mask should also be added to target
181
- # For now leave it at the cost of incorrect upper bound on the scores
135
+ def compute_schedule(
136
+ args,
137
+ target: Density,
138
+ matching_data: MatchingData,
139
+ callback_class,
140
+ pad_edges: bool = False,
141
+ ):
142
+ # User requested target padding
143
+ if args.pad_edges is True:
144
+ pad_edges = True
145
+ template_box = matching_data._output_template_shape
146
+ if not args.pad_fourier:
147
+ template_box = tuple(0 for _ in range(len(template_box)))
148
+
149
+ target_padding = tuple(0 for _ in range(len(template_box)))
150
+ if pad_edges:
151
+ target_padding = matching_data._output_template_shape
152
+
153
+ splits, schedule = compute_parallelization_schedule(
154
+ shape1=target.shape,
155
+ shape2=tuple(int(x) for x in template_box),
156
+ shape1_padding=tuple(int(x) for x in target_padding),
157
+ max_cores=args.cores,
158
+ max_ram=args.memory,
159
+ split_only_outer=args.use_gpu,
160
+ matching_method=args.score,
161
+ analyzer_method=callback_class.__name__,
162
+ backend=be._backend_name,
163
+ float_nbytes=be.datatype_bytes(be._float_dtype),
164
+ complex_nbytes=be.datatype_bytes(be._complex_dtype),
165
+ integer_nbytes=be.datatype_bytes(be._int_dtype),
166
+ )
167
+
168
+ if splits is None:
169
+ print(
170
+ "Found no suitable parallelization schedule. Consider increasing"
171
+ " available RAM or decreasing number of cores."
172
+ )
173
+ exit(-1)
174
+ n_splits = np.prod(list(splits.values()))
175
+ if pad_edges is False and n_splits > 1:
176
+ args.pad_edges = True
177
+ return compute_schedule(args, target, matching_data, callback_class, True)
178
+ return splits, schedule
179
+
180
+
182
181
  def setup_filter(args, template: Density, target: Density) -> Tuple[Compose, Compose]:
183
182
  from tme.preprocessing import LinearWhiteningFilter, BandPassFilter
184
183
  from tme.preprocessing.tilt_series import (
@@ -234,18 +233,23 @@ def setup_filter(args, template: Density, target: Density) -> Tuple[Compose, Com
234
233
  weight_wedge=args.tilt_weighting == "angle",
235
234
  create_continuous_wedge=create_continuous_wedge,
236
235
  )
236
+ wedge_target = WedgeReconstructed(
237
+ angles=tilt_angles,
238
+ weight_wedge=False,
239
+ create_continuous_wedge=create_continuous_wedge,
240
+ )
241
+ target_filter.append(wedge_target)
237
242
 
238
243
  wedge.opening_axis = args.wedge_axes[0]
239
244
  wedge.tilt_axis = args.wedge_axes[1]
240
245
  wedge.sampling_rate = template.sampling_rate
241
246
  template_filter.append(wedge)
242
247
  if not isinstance(wedge, WedgeReconstructed):
243
- template_filter.append(
244
- ReconstructFromTilt(
245
- reconstruction_filter=args.reconstruction_filter,
246
- interpolation_order=args.reconstruction_interpolation_order,
247
- )
248
+ reconstruction_filter = ReconstructFromTilt(
249
+ reconstruction_filter=args.reconstruction_filter,
250
+ interpolation_order=args.reconstruction_interpolation_order,
248
251
  )
252
+ template_filter.append(reconstruction_filter)
249
253
 
250
254
  if args.ctf_file is not None or args.defocus is not None:
251
255
  from tme.preprocessing.tilt_series import CTF
@@ -273,7 +277,7 @@ def setup_filter(args, template: Density, target: Density) -> Tuple[Compose, Com
273
277
  return_real_fourier=True,
274
278
  )
275
279
  ctf.sampling_rate = template.sampling_rate
276
- ctf.flip_phase = not args.no_flip_phase
280
+ ctf.flip_phase = args.no_flip_phase
277
281
  ctf.amplitude_contrast = args.amplitude_contrast
278
282
  ctf.spherical_aberration = args.spherical_aberration
279
283
  ctf.acceleration_voltage = args.acceleration_voltage * 1e3
@@ -306,6 +310,12 @@ def setup_filter(args, template: Density, target: Density) -> Tuple[Compose, Com
306
310
  if highpass is not None:
307
311
  highpass = np.max(np.divide(template.sampling_rate, highpass))
308
312
 
313
+ try:
314
+ if args.lowpass >= args.highpass:
315
+ warnings.warn("--lowpass should be smaller than --highpass.")
316
+ except Exception:
317
+ pass
318
+
309
319
  bandpass = BandPassFilter(
310
320
  use_gaussian=args.no_pass_smooth,
311
321
  lowpass=lowpass,
@@ -313,7 +323,9 @@ def setup_filter(args, template: Density, target: Density) -> Tuple[Compose, Com
313
323
  sampling_rate=template.sampling_rate,
314
324
  )
315
325
  template_filter.append(bandpass)
316
- target_filter.append(bandpass)
326
+
327
+ if not args.no_filter_target:
328
+ target_filter.append(bandpass)
317
329
 
318
330
  if args.whiten_spectrum:
319
331
  whitening_filter = LinearWhiteningFilter()
@@ -335,7 +347,10 @@ def setup_filter(args, template: Density, target: Density) -> Tuple[Compose, Com
335
347
 
336
348
 
337
349
  def parse_args():
338
- parser = argparse.ArgumentParser(description="Perform template matching.")
350
+ parser = argparse.ArgumentParser(
351
+ description="Perform template matching.",
352
+ formatter_class=argparse.ArgumentDefaultsHelpFormatter,
353
+ )
339
354
 
340
355
  io_group = parser.add_argument_group("Input / Output")
341
356
  io_group.add_argument(
@@ -384,8 +399,7 @@ def parse_args():
384
399
  dest="invert_target_contrast",
385
400
  action="store_true",
386
401
  default=False,
387
- help="Invert the target's contrast and rescale linearly between zero and one. "
388
- "This option is intended for targets where templates to-be-matched have "
402
+ help="Invert the target's contrast for cases where templates to-be-matched have "
389
403
  "negative values, e.g. tomograms.",
390
404
  )
391
405
  io_group.add_argument(
@@ -405,13 +419,6 @@ def parse_args():
405
419
  choices=list(MATCHING_EXHAUSTIVE_REGISTER.keys()),
406
420
  help="Template matching scoring function.",
407
421
  )
408
- scoring_group.add_argument(
409
- "-p",
410
- dest="peak_calling",
411
- action="store_true",
412
- default=False,
413
- help="Perform peak calling instead of score aggregation.",
414
- )
415
422
 
416
423
  angular_group = parser.add_argument_group("Angular Sampling")
417
424
  angular_exclusive = angular_group.add_mutually_exclusive_group(required=True)
@@ -445,7 +452,7 @@ def parse_args():
445
452
  type=check_positive,
446
453
  default=360.0,
447
454
  required=False,
448
- help="Sampling angle along the z-axis of the cone. Defaults to 360.",
455
+ help="Sampling angle along the z-axis of the cone.",
449
456
  )
450
457
  angular_group.add_argument(
451
458
  "--axis_sampling",
@@ -513,8 +520,7 @@ def parse_args():
513
520
  required=False,
514
521
  type=float,
515
522
  default=0.85,
516
- help="Fraction of available memory that can be used. Defaults to 0.85 and is "
517
- "ignored if --ram is set",
523
+ help="Fraction of available memory to be used. Ignored if --ram is set.",
518
524
  )
519
525
  computation_group.add_argument(
520
526
  "--temp_directory",
@@ -522,7 +528,13 @@ def parse_args():
522
528
  default=None,
523
529
  help="Directory for temporary objects. Faster I/O improves runtime.",
524
530
  )
525
-
531
+ computation_group.add_argument(
532
+ "--backend",
533
+ dest="backend",
534
+ default=None,
535
+ choices=be.available_backends(),
536
+ help="[Expert] Overwrite default computation backend.",
537
+ )
526
538
  filter_group = parser.add_argument_group("Filters")
527
539
  filter_group.add_argument(
528
540
  "--lowpass",
@@ -552,9 +564,9 @@ def parse_args():
552
564
  dest="pass_format",
553
565
  type=str,
554
566
  required=False,
567
+ default="sampling_rate",
555
568
  choices=["sampling_rate", "voxel", "frequency"],
556
- help="How values passed to --lowpass and --highpass should be interpreted. "
557
- "By default, they are assumed to be in units of sampling rate, e.g. Ångstrom.",
569
+ help="How values passed to --lowpass and --highpass should be interpreted. ",
558
570
  )
559
571
  filter_group.add_argument(
560
572
  "--whiten_spectrum",
@@ -613,6 +625,13 @@ def parse_args():
613
625
  required=False,
614
626
  help="Analogous to --interpolation_order but for reconstruction.",
615
627
  )
628
+ filter_group.add_argument(
629
+ "--no_filter_target",
630
+ dest="no_filter_target",
631
+ action="store_true",
632
+ default=False,
633
+ help="Whether to not apply potential filters to the target.",
634
+ )
616
635
 
617
636
  ctf_group = parser.add_argument_group("Contrast Transfer Function")
618
637
  ctf_group.add_argument(
@@ -647,7 +666,7 @@ def parse_args():
647
666
  type=float,
648
667
  required=False,
649
668
  default=300,
650
- help="Acceleration voltage in kV, defaults to 300.",
669
+ help="Acceleration voltage in kV.",
651
670
  )
652
671
  ctf_group.add_argument(
653
672
  "--spherical_aberration",
@@ -663,14 +682,14 @@ def parse_args():
663
682
  type=float,
664
683
  required=False,
665
684
  default=0.07,
666
- help="Amplitude contrast, defaults to 0.07.",
685
+ help="Amplitude contrast.",
667
686
  )
668
687
  ctf_group.add_argument(
669
688
  "--no_flip_phase",
670
689
  dest="no_flip_phase",
671
690
  action="store_false",
672
691
  required=False,
673
- help="Whether the phase of the computed CTF should not be flipped.",
692
+ help="Do not perform phase-flipping CTF correction.",
674
693
  )
675
694
  ctf_group.add_argument(
676
695
  "--correct_defocus_gradient",
@@ -682,22 +701,6 @@ def parse_args():
682
701
  )
683
702
 
684
703
  performance_group = parser.add_argument_group("Performance")
685
- performance_group.add_argument(
686
- "--cutoff_target",
687
- dest="cutoff_target",
688
- type=float,
689
- required=False,
690
- default=None,
691
- help="Target contour level (used for cropping).",
692
- )
693
- performance_group.add_argument(
694
- "--cutoff_template",
695
- dest="cutoff_template",
696
- type=float,
697
- required=False,
698
- default=None,
699
- help="Template contour level (used for cropping).",
700
- )
701
704
  performance_group.add_argument(
702
705
  "--no_centering",
703
706
  dest="no_centering",
@@ -705,21 +708,28 @@ def parse_args():
705
708
  help="Assumes the template is already centered and omits centering.",
706
709
  )
707
710
  performance_group.add_argument(
708
- "--no_edge_padding",
709
- dest="no_edge_padding",
711
+ "--pad_edges",
712
+ dest="pad_edges",
710
713
  action="store_true",
711
714
  default=False,
712
- help="Whether to not pad the edges of the target. Can be set if the target"
713
- " has a well defined bounding box, e.g. a masked reconstruction.",
715
+ help="Whether to pad the edges of the target. Useful if the target does not have "
716
+ "a well-defined bounding box. Defaults to True if splitting is required.",
714
717
  )
715
718
  performance_group.add_argument(
716
- "--no_fourier_padding",
717
- dest="no_fourier_padding",
719
+ "--pad_fourier",
720
+ dest="pad_fourier",
718
721
  action="store_true",
719
722
  default=False,
720
723
  help="Whether input arrays should not be zero-padded to full convolution shape "
721
- "for numerical stability. When working with very large targets, e.g. tomograms, "
722
- "it is safe to use this flag and benefit from the performance gain.",
724
+ "for numerical stability. Typically only useful when working with small data.",
725
+ )
726
+ performance_group.add_argument(
727
+ "--pad_filter",
728
+ dest="pad_filter",
729
+ action="store_true",
730
+ default=False,
731
+ help="Pads the filter to the shape of the target. Particularly useful for fast "
732
+ "oscillating filters to avoid aliasing effects.",
723
733
  )
724
734
  performance_group.add_argument(
725
735
  "--interpolation_order",
@@ -727,8 +737,7 @@ def parse_args():
727
737
  required=False,
728
738
  type=int,
729
739
  default=3,
730
- help="Spline interpolation used for template rotations. If less than zero "
731
- "no interpolation is performed.",
740
+ help="Spline interpolation used for rotations.",
732
741
  )
733
742
  performance_group.add_argument(
734
743
  "--use_mixed_precision",
@@ -755,7 +764,20 @@ def parse_args():
755
764
  default=0,
756
765
  help="Minimum template matching scores to consider for analysis.",
757
766
  )
758
-
767
+ analyzer_group.add_argument(
768
+ "-p",
769
+ dest="peak_calling",
770
+ action="store_true",
771
+ default=False,
772
+ help="Perform peak calling instead of score aggregation.",
773
+ )
774
+ analyzer_group.add_argument(
775
+ "--number_of_peaks",
776
+ dest="number_of_peaks",
777
+ action="store_true",
778
+ default=1000,
779
+ help="Number of peaks to call, 1000 by default.",
780
+ )
759
781
  args = parser.parse_args()
760
782
  args.version = __version__
761
783
 
@@ -770,9 +792,6 @@ def parse_args():
770
792
 
771
793
  os.environ["TMPDIR"] = args.temp_directory
772
794
 
773
- args.pad_target_edges = not args.no_edge_padding
774
- args.pad_fourier = not args.no_fourier_padding
775
-
776
795
  if args.score not in MATCHING_EXHAUSTIVE_REGISTER:
777
796
  raise ValueError(
778
797
  f"score has to be one of {', '.join(MATCHING_EXHAUSTIVE_REGISTER.keys())}"
@@ -827,7 +846,9 @@ def main():
827
846
  )
828
847
 
829
848
  if target.sampling_rate.size == template.sampling_rate.size:
830
- if not np.allclose(target.sampling_rate, template.sampling_rate):
849
+ if not np.allclose(
850
+ np.round(target.sampling_rate, 2), np.round(template.sampling_rate, 2)
851
+ ):
831
852
  print(
832
853
  f"Resampling template to {target.sampling_rate}. "
833
854
  "Consider providing a template with the same sampling rate as the target."
@@ -842,9 +863,6 @@ def main():
842
863
  )
843
864
 
844
865
  initial_shape = target.shape
845
- is_cropped = crop_data(
846
- data=target, data_mask=target_mask, cutoff=args.cutoff_target
847
- )
848
866
  print_block(
849
867
  name="Target",
850
868
  data={
@@ -853,13 +871,6 @@ def main():
853
871
  "Final Shape": target.shape,
854
872
  },
855
873
  )
856
- if is_cropped:
857
- args.target = generate_tempfile_name(suffix=".mrc")
858
- target.to_file(args.target)
859
-
860
- if target_mask:
861
- args.target_mask = generate_tempfile_name(suffix=".mrc")
862
- target_mask.to_file(args.target_mask)
863
874
 
864
875
  if target_mask:
865
876
  print_block(
@@ -872,8 +883,6 @@ def main():
872
883
  )
873
884
 
874
885
  initial_shape = template.shape
875
- _ = crop_data(data=template, data_mask=template_mask, cutoff=args.cutoff_template)
876
-
877
886
  translation = np.zeros(len(template.shape), dtype=np.float32)
878
887
  if not args.no_centering:
879
888
  template, translation = template.centered(0)
@@ -921,47 +930,62 @@ def main():
921
930
 
922
931
  if args.scramble_phases:
923
932
  template.data = scramble_phases(
924
- template.data, noise_proportion=1.0, normalize_power=True
933
+ template.data, noise_proportion=1.0, normalize_power=False
925
934
  )
926
935
 
927
- available_memory = backend.get_available_memory()
936
+ # Determine suitable backend for the selected operation
937
+ available_backends = be.available_backends()
938
+ if args.backend is not None:
939
+ req_backend = args.backend
940
+ if req_backend not in available_backends:
941
+ raise ValueError("Requested backend is not available.")
942
+ available_backends = [req_backend]
943
+
944
+ be_selection = ("numpyfftw", "pytorch", "jax", "mlx")
928
945
  if args.use_gpu:
929
946
  args.cores = len(args.gpu_indices)
930
- has_torch = importlib.util.find_spec("torch") is not None
931
- has_cupy = importlib.util.find_spec("cupy") is not None
947
+ be_selection = ("pytorch", "cupy", "jax")
948
+ if args.use_mixed_precision:
949
+ be_selection = tuple(x for x in be_selection if x in ("cupy", "numpyfftw"))
932
950
 
933
- if not has_torch and not has_cupy:
934
- raise ValueError(
935
- "Found neither CuPy nor PyTorch installation. You need to install"
936
- " either to enable GPU support."
951
+ available_backends = [x for x in available_backends if x in be_selection]
952
+ if args.peak_calling:
953
+ if "jax" in available_backends:
954
+ available_backends.remove("jax")
955
+ if args.use_gpu and "pytorch" in available_backends:
956
+ available_backends = ("pytorch",)
957
+
958
+ # dim_match = len(template.shape) == len(target.shape) <= 3
959
+ # if dim_match and args.use_gpu and "jax" in available_backends:
960
+ # args.interpolation_order = 1
961
+ # available_backends = ["jax"]
962
+
963
+ backend_preference = ("numpyfftw", "pytorch", "jax", "mlx")
964
+ if args.use_gpu:
965
+ backend_preference = ("cupy", "pytorch", "jax")
966
+ for pref in backend_preference:
967
+ if pref not in available_backends:
968
+ continue
969
+ be.change_backend(pref)
970
+ if pref == "pytorch":
971
+ be.change_backend(pref, device="cuda" if args.use_gpu else "cpu")
972
+
973
+ if args.use_mixed_precision:
974
+ be.change_backend(
975
+ backend_name=pref,
976
+ default_dtype=be._array_backend.float16,
977
+ complex_dtype=be._array_backend.complex64,
978
+ default_dtype_int=be._array_backend.int16,
937
979
  )
980
+ break
938
981
 
939
- if args.peak_calling:
940
- preferred_backend = "pytorch"
941
- if not has_torch:
942
- preferred_backend = "cupy"
943
- backend.change_backend(backend_name=preferred_backend, device="cuda")
944
- else:
945
- preferred_backend = "cupy"
946
- if not has_cupy:
947
- preferred_backend = "pytorch"
948
- backend.change_backend(backend_name=preferred_backend, device="cuda")
949
- if args.use_mixed_precision and preferred_backend == "pytorch":
950
- raise NotImplementedError(
951
- "pytorch backend does not yet support mixed precision."
952
- " Consider installing CuPy to enable this feature."
953
- )
954
- elif args.use_mixed_precision:
955
- backend.change_backend(
956
- backend_name="cupy",
957
- default_dtype=backend._array_backend.float16,
958
- complex_dtype=backend._array_backend.complex64,
959
- default_dtype_int=backend._array_backend.int16,
960
- )
961
- available_memory = backend.get_available_memory() * args.cores
962
- if preferred_backend == "pytorch" and args.interpolation_order == 3:
963
- args.interpolation_order = 1
982
+ if pref == "pytorch" and args.interpolation_order == 3:
983
+ warnings.warn(
984
+ "Pytorch does not support --interpolation_order 3, setting it to 1."
985
+ )
986
+ args.interpolation_order = 1
964
987
 
988
+ available_memory = be.get_available_memory() * be.device_count()
965
989
  if args.memory is None:
966
990
  args.memory = int(args.memory_scaling * available_memory)
967
991
 
@@ -978,70 +1002,49 @@ def main():
978
1002
  rotations=parse_rotation_logic(args=args, ndim=template.data.ndim),
979
1003
  )
980
1004
 
981
- template_filter, target_filter = setup_filter(args, template, target)
982
- matching_data.template_filter = template_filter
983
- matching_data.target_filter = target_filter
984
-
985
- template_box = matching_data._output_template_shape
986
- if not args.pad_fourier:
987
- template_box = np.ones(len(template_box), dtype=int)
988
-
989
- target_padding = np.zeros(
990
- (backend.size(matching_data._output_template_shape)), dtype=int
991
- )
992
- if args.pad_target_edges:
993
- target_padding = matching_data._output_template_shape
994
-
995
- splits, schedule = compute_parallelization_schedule(
996
- shape1=target.shape,
997
- shape2=tuple(int(x) for x in template_box),
998
- shape1_padding=tuple(int(x) for x in target_padding),
999
- max_cores=args.cores,
1000
- max_ram=args.memory,
1001
- split_only_outer=args.use_gpu,
1002
- matching_method=args.score,
1003
- analyzer_method=callback_class.__name__,
1004
- backend=backend._backend_name,
1005
- float_nbytes=backend.datatype_bytes(backend._float_dtype),
1006
- complex_nbytes=backend.datatype_bytes(backend._complex_dtype),
1007
- integer_nbytes=backend.datatype_bytes(backend._int_dtype),
1005
+ matching_setup, matching_score = MATCHING_EXHAUSTIVE_REGISTER[args.score]
1006
+ matching_data.template_filter, matching_data.target_filter = setup_filter(
1007
+ args, template, target
1008
1008
  )
1009
1009
 
1010
- if splits is None:
1011
- print(
1012
- "Found no suitable parallelization schedule. Consider increasing"
1013
- " available RAM or decreasing number of cores."
1014
- )
1015
- exit(-1)
1010
+ splits, schedule = compute_schedule(args, target, matching_data, callback_class)
1016
1011
 
1017
- matching_setup, matching_score = MATCHING_EXHAUSTIVE_REGISTER[args.score]
1018
1012
  n_splits = np.prod(list(splits.values()))
1019
1013
  target_split = ", ".join(
1020
1014
  [":".join([str(x) for x in axis]) for axis in splits.items()]
1021
1015
  )
1022
1016
  gpus_used = 0 if args.gpu_indices is None else len(args.gpu_indices)
1023
1017
  options = {
1024
- "CPU Cores": args.cores,
1025
- "Run on GPU": f"{args.use_gpu} [N={gpus_used}]",
1026
- "Use Mixed Precision": args.use_mixed_precision,
1027
- "Assigned Memory [MB]": f"{args.memory // 1e6} [out of {available_memory//1e6}]",
1028
- "Temporary Directory": args.temp_directory,
1029
- "Extend Fourier Grid": not args.no_fourier_padding,
1030
- "Extend Target Edges": not args.no_edge_padding,
1031
- "Interpolation Order": args.interpolation_order,
1032
- "Score": f"{args.score}",
1033
- "Setup Function": f"{get_func_fullname(matching_setup)}",
1034
- "Scoring Function": f"{get_func_fullname(matching_score)}",
1035
1018
  "Angular Sampling": f"{args.angular_sampling}"
1036
1019
  f" [{matching_data.rotations.shape[0]} rotations]",
1020
+ "Center Template": not args.no_centering,
1037
1021
  "Scramble Template": args.scramble_phases,
1038
- "Target Splits": f"{target_split} [N={n_splits}]",
1022
+ "Invert Contrast": args.invert_target_contrast,
1023
+ "Extend Fourier Grid": args.pad_fourier,
1024
+ "Extend Target Edges": args.pad_edges,
1025
+ "Interpolation Order": args.interpolation_order,
1026
+ "Setup Function": f"{get_func_fullname(matching_setup)}",
1027
+ "Scoring Function": f"{get_func_fullname(matching_score)}",
1039
1028
  }
1040
1029
 
1041
1030
  print_block(
1042
- name="Template Matching Options",
1031
+ name="Template Matching",
1043
1032
  data=options,
1044
- label_width=max(len(key) for key in options.keys()) + 2,
1033
+ label_width=max(len(key) for key in options.keys()) + 3,
1034
+ )
1035
+
1036
+ compute_options = {
1037
+ "Backend": be._BACKEND_REGISTRY[be._backend_name],
1038
+ "Compute Devices": f"CPU [{args.cores}], GPU [{gpus_used}]",
1039
+ "Use Mixed Precision": args.use_mixed_precision,
1040
+ "Assigned Memory [MB]": f"{args.memory // 1e6} [out of {available_memory//1e6}]",
1041
+ "Temporary Directory": args.temp_directory,
1042
+ "Target Splits": f"{target_split} [N={n_splits}]",
1043
+ }
1044
+ print_block(
1045
+ name="Computation",
1046
+ data=compute_options,
1047
+ label_width=max(len(key) for key in options.keys()) + 3,
1045
1048
  )
1046
1049
 
1047
1050
  filter_args = {
@@ -1054,12 +1057,13 @@ def main():
1054
1057
  "Tilt Angles": args.tilt_angles,
1055
1058
  "Tilt Weighting": args.tilt_weighting,
1056
1059
  "Reconstruction Filter": args.reconstruction_filter,
1060
+ "Extend Filter Grid": args.pad_filter,
1057
1061
  }
1058
1062
  if args.ctf_file is not None or args.defocus is not None:
1059
1063
  filter_args["CTF File"] = args.ctf_file
1060
1064
  filter_args["Defocus"] = args.defocus
1061
1065
  filter_args["Phase Shift"] = args.phase_shift
1062
- filter_args["No Flip Phase"] = args.no_flip_phase
1066
+ filter_args["Flip Phase"] = args.no_flip_phase
1063
1067
  filter_args["Acceleration Voltage"] = args.acceleration_voltage
1064
1068
  filter_args["Spherical Aberration"] = args.spherical_aberration
1065
1069
  filter_args["Amplitude Contrast"] = args.amplitude_contrast
@@ -1070,20 +1074,19 @@ def main():
1070
1074
  print_block(
1071
1075
  name="Filters",
1072
1076
  data=filter_args,
1073
- label_width=max(len(key) for key in options.keys()) + 2,
1077
+ label_width=max(len(key) for key in options.keys()) + 3,
1074
1078
  )
1075
1079
 
1076
1080
  analyzer_args = {
1077
1081
  "score_threshold": args.score_threshold,
1078
- "number_of_peaks": 1000,
1079
- "convolution_mode": "valid",
1082
+ "number_of_peaks": args.number_of_peaks,
1083
+ "min_distance": max(template.shape) // 3,
1080
1084
  "use_memmap": args.use_memmap,
1081
1085
  }
1082
- analyzer_args = {"Analyzer": callback_class, **analyzer_args}
1083
1086
  print_block(
1084
- name="Score Analysis Options",
1085
- data=analyzer_args,
1086
- label_width=max(len(key) for key in options.keys()) + 2,
1087
+ name="Analyzer",
1088
+ data={"Analyzer": callback_class, **analyzer_args},
1089
+ label_width=max(len(key) for key in options.keys()) + 3,
1087
1090
  )
1088
1091
  print("\n" + "-" * 80)
1089
1092
 
@@ -1102,8 +1105,9 @@ def main():
1102
1105
  callback_class=callback_class,
1103
1106
  callback_class_args=analyzer_args,
1104
1107
  target_splits=splits,
1105
- pad_target_edges=args.pad_target_edges,
1108
+ pad_target_edges=args.pad_edges,
1106
1109
  pad_fourier=args.pad_fourier,
1110
+ pad_template_filter=args.pad_filter,
1107
1111
  interpolation_order=args.interpolation_order,
1108
1112
  )
1109
1113
 
@@ -1113,19 +1117,18 @@ def main():
1113
1117
  candidates[0] *= target_mask.data
1114
1118
  with warnings.catch_warnings():
1115
1119
  warnings.simplefilter("ignore", category=UserWarning)
1116
- nbytes = backend.datatype_bytes(backend._float_dtype)
1120
+ nbytes = be.datatype_bytes(be._float_dtype)
1117
1121
  dtype = np.float32 if nbytes == 4 else np.float16
1118
1122
  rot_dim = matching_data.rotations.shape[1]
1119
1123
  candidates[3] = {
1120
- x: euler_from_rotationmatrix(
1121
- np.frombuffer(i, dtype=dtype).reshape(rot_dim, rot_dim)
1122
- )
1124
+ x: np.frombuffer(i, dtype=dtype).reshape(rot_dim, rot_dim)
1123
1125
  for i, x in candidates[3].items()
1124
1126
  }
1125
1127
  candidates.append((target.origin, template.origin, template.sampling_rate, args))
1126
1128
  write_pickle(data=candidates, filename=args.output)
1127
1129
 
1128
1130
  runtime = time() - start
1131
+ print("\n" + "-" * 80)
1129
1132
  print(f"\nRuntime real: {runtime:.3f}s user: {(runtime * args.cores):.3f}s.")
1130
1133
 
1131
1134