pytme 0.2.1__cp311-cp311-macosx_14_0_arm64.whl → 0.2.2__cp311-cp311-macosx_14_0_arm64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (49)
  1. {pytme-0.2.1.data → pytme-0.2.2.data}/scripts/match_template.py +147 -93
  2. {pytme-0.2.1.data → pytme-0.2.2.data}/scripts/postprocess.py +67 -26
  3. {pytme-0.2.1.data → pytme-0.2.2.data}/scripts/preprocessor_gui.py +175 -85
  4. pytme-0.2.2.dist-info/METADATA +91 -0
  5. pytme-0.2.2.dist-info/RECORD +74 -0
  6. {pytme-0.2.1.dist-info → pytme-0.2.2.dist-info}/WHEEL +1 -1
  7. scripts/extract_candidates.py +20 -13
  8. scripts/match_template.py +147 -93
  9. scripts/match_template_filters.py +154 -95
  10. scripts/postprocess.py +67 -26
  11. scripts/preprocessor_gui.py +175 -85
  12. scripts/refine_matches.py +265 -61
  13. tme/__init__.py +0 -1
  14. tme/__version__.py +1 -1
  15. tme/analyzer.py +451 -809
  16. tme/backends/__init__.py +40 -11
  17. tme/backends/_jax_utils.py +185 -0
  18. tme/backends/cupy_backend.py +111 -223
  19. tme/backends/jax_backend.py +214 -150
  20. tme/backends/matching_backend.py +445 -384
  21. tme/backends/mlx_backend.py +32 -59
  22. tme/backends/npfftw_backend.py +239 -507
  23. tme/backends/pytorch_backend.py +21 -145
  24. tme/density.py +233 -363
  25. tme/extensions.cpython-311-darwin.so +0 -0
  26. tme/matching_data.py +322 -285
  27. tme/matching_exhaustive.py +172 -1493
  28. tme/matching_optimization.py +143 -106
  29. tme/matching_scores.py +884 -0
  30. tme/matching_utils.py +280 -386
  31. tme/memory.py +377 -0
  32. tme/orientations.py +52 -12
  33. tme/parser.py +3 -4
  34. tme/preprocessing/_utils.py +61 -32
  35. tme/preprocessing/compose.py +7 -3
  36. tme/preprocessing/frequency_filters.py +49 -39
  37. tme/preprocessing/tilt_series.py +34 -40
  38. tme/preprocessor.py +560 -526
  39. tme/structure.py +491 -188
  40. tme/types.py +5 -3
  41. pytme-0.2.1.dist-info/METADATA +0 -73
  42. pytme-0.2.1.dist-info/RECORD +0 -73
  43. tme/helpers.py +0 -881
  44. tme/matching_constrained.py +0 -195
  45. {pytme-0.2.1.data → pytme-0.2.2.data}/scripts/estimate_ram_usage.py +0 -0
  46. {pytme-0.2.1.data → pytme-0.2.2.data}/scripts/preprocess.py +0 -0
  47. {pytme-0.2.1.dist-info → pytme-0.2.2.dist-info}/LICENSE +0 -0
  48. {pytme-0.2.1.dist-info → pytme-0.2.2.dist-info}/entry_points.txt +0 -0
  49. {pytme-0.2.1.dist-info → pytme-0.2.2.dist-info}/top_level.txt +0 -0
scripts/match_template.py CHANGED
@@ -8,7 +8,6 @@
8
8
  import os
9
9
  import argparse
10
10
  import warnings
11
- import importlib.util
12
11
  from sys import exit
13
12
  from time import time
14
13
  from typing import Tuple
@@ -22,7 +21,6 @@ from tme.matching_utils import (
22
21
  get_rotation_matrices,
23
22
  get_rotations_around_vector,
24
23
  compute_parallelization_schedule,
25
- euler_from_rotationmatrix,
26
24
  scramble_phases,
27
25
  generate_tempfile_name,
28
26
  write_pickle,
@@ -33,7 +31,7 @@ from tme.analyzer import (
33
31
  MaxScoreOverRotations,
34
32
  PeakCallerMaximumFilter,
35
33
  )
36
- from tme.backends import backend
34
+ from tme.backends import backend as be
37
35
  from tme.preprocessing import Compose
38
36
 
39
37
 
@@ -52,7 +50,7 @@ def print_block(name: str, data: dict, label_width=20) -> None:
52
50
 
53
51
  def print_entry() -> None:
54
52
  width = 80
55
- text = f" pyTME v{__version__} "
53
+ text = f" pytme v{__version__} "
56
54
  padding_total = width - len(text) - 2
57
55
  padding_left = padding_total // 2
58
56
  padding_right = padding_total - padding_left
@@ -273,7 +271,7 @@ def setup_filter(args, template: Density, target: Density) -> Tuple[Compose, Com
273
271
  return_real_fourier=True,
274
272
  )
275
273
  ctf.sampling_rate = template.sampling_rate
276
- ctf.flip_phase = not args.no_flip_phase
274
+ ctf.flip_phase = args.no_flip_phase
277
275
  ctf.amplitude_contrast = args.amplitude_contrast
278
276
  ctf.spherical_aberration = args.spherical_aberration
279
277
  ctf.acceleration_voltage = args.acceleration_voltage * 1e3
@@ -306,6 +304,12 @@ def setup_filter(args, template: Density, target: Density) -> Tuple[Compose, Com
306
304
  if highpass is not None:
307
305
  highpass = np.max(np.divide(template.sampling_rate, highpass))
308
306
 
307
+ try:
308
+ if args.lowpass >= args.highpass:
309
+ warnings.warn("--lowpass should be smaller than --highpass.")
310
+ except Exception:
311
+ pass
312
+
309
313
  bandpass = BandPassFilter(
310
314
  use_gaussian=args.no_pass_smooth,
311
315
  lowpass=lowpass,
@@ -313,7 +317,9 @@ def setup_filter(args, template: Density, target: Density) -> Tuple[Compose, Com
313
317
  sampling_rate=template.sampling_rate,
314
318
  )
315
319
  template_filter.append(bandpass)
316
- target_filter.append(bandpass)
320
+
321
+ if not args.no_filter_target:
322
+ target_filter.append(bandpass)
317
323
 
318
324
  if args.whiten_spectrum:
319
325
  whitening_filter = LinearWhiteningFilter()
@@ -335,7 +341,10 @@ def setup_filter(args, template: Density, target: Density) -> Tuple[Compose, Com
335
341
 
336
342
 
337
343
  def parse_args():
338
- parser = argparse.ArgumentParser(description="Perform template matching.")
344
+ parser = argparse.ArgumentParser(
345
+ description="Perform template matching.",
346
+ formatter_class=argparse.ArgumentDefaultsHelpFormatter,
347
+ )
339
348
 
340
349
  io_group = parser.add_argument_group("Input / Output")
341
350
  io_group.add_argument(
@@ -405,13 +414,6 @@ def parse_args():
405
414
  choices=list(MATCHING_EXHAUSTIVE_REGISTER.keys()),
406
415
  help="Template matching scoring function.",
407
416
  )
408
- scoring_group.add_argument(
409
- "-p",
410
- dest="peak_calling",
411
- action="store_true",
412
- default=False,
413
- help="Perform peak calling instead of score aggregation.",
414
- )
415
417
 
416
418
  angular_group = parser.add_argument_group("Angular Sampling")
417
419
  angular_exclusive = angular_group.add_mutually_exclusive_group(required=True)
@@ -445,7 +447,7 @@ def parse_args():
445
447
  type=check_positive,
446
448
  default=360.0,
447
449
  required=False,
448
- help="Sampling angle along the z-axis of the cone. Defaults to 360.",
450
+ help="Sampling angle along the z-axis of the cone.",
449
451
  )
450
452
  angular_group.add_argument(
451
453
  "--axis_sampling",
@@ -513,8 +515,7 @@ def parse_args():
513
515
  required=False,
514
516
  type=float,
515
517
  default=0.85,
516
- help="Fraction of available memory that can be used. Defaults to 0.85 and is "
517
- "ignored if --ram is set",
518
+ help="Fraction of available memory to be used. Ignored if --ram is set.",
518
519
  )
519
520
  computation_group.add_argument(
520
521
  "--temp_directory",
@@ -522,7 +523,13 @@ def parse_args():
522
523
  default=None,
523
524
  help="Directory for temporary objects. Faster I/O improves runtime.",
524
525
  )
525
-
526
+ computation_group.add_argument(
527
+ "--backend",
528
+ dest="backend",
529
+ default=None,
530
+ choices=be.available_backends(),
531
+ help="[Expert] Overwrite default computation backend.",
532
+ )
526
533
  filter_group = parser.add_argument_group("Filters")
527
534
  filter_group.add_argument(
528
535
  "--lowpass",
@@ -552,9 +559,9 @@ def parse_args():
552
559
  dest="pass_format",
553
560
  type=str,
554
561
  required=False,
562
+ default="sampling_rate",
555
563
  choices=["sampling_rate", "voxel", "frequency"],
556
- help="How values passed to --lowpass and --highpass should be interpreted. "
557
- "By default, they are assumed to be in units of sampling rate, e.g. Ångstrom.",
564
+ help="How values passed to --lowpass and --highpass should be interpreted. ",
558
565
  )
559
566
  filter_group.add_argument(
560
567
  "--whiten_spectrum",
@@ -613,6 +620,13 @@ def parse_args():
613
620
  required=False,
614
621
  help="Analogous to --interpolation_order but for reconstruction.",
615
622
  )
623
+ filter_group.add_argument(
624
+ "--no_filter_target",
625
+ dest="no_filter_target",
626
+ action="store_true",
627
+ default=False,
628
+ help="Whether to not apply potential filters to the target.",
629
+ )
616
630
 
617
631
  ctf_group = parser.add_argument_group("Contrast Transfer Function")
618
632
  ctf_group.add_argument(
@@ -647,7 +661,7 @@ def parse_args():
647
661
  type=float,
648
662
  required=False,
649
663
  default=300,
650
- help="Acceleration voltage in kV, defaults to 300.",
664
+ help="Acceleration voltage in kV.",
651
665
  )
652
666
  ctf_group.add_argument(
653
667
  "--spherical_aberration",
@@ -663,14 +677,14 @@ def parse_args():
663
677
  type=float,
664
678
  required=False,
665
679
  default=0.07,
666
- help="Amplitude contrast, defaults to 0.07.",
680
+ help="Amplitude contrast.",
667
681
  )
668
682
  ctf_group.add_argument(
669
683
  "--no_flip_phase",
670
684
  dest="no_flip_phase",
671
685
  action="store_false",
672
686
  required=False,
673
- help="Whether the phase of the computed CTF should not be flipped.",
687
+ help="Do not perform phase-flipping CTF correction.",
674
688
  )
675
689
  ctf_group.add_argument(
676
690
  "--correct_defocus_gradient",
@@ -721,14 +735,22 @@ def parse_args():
721
735
  "for numerical stability. When working with very large targets, e.g. tomograms, "
722
736
  "it is safe to use this flag and benefit from the performance gain.",
723
737
  )
738
+ performance_group.add_argument(
739
+ "--no_filter_padding",
740
+ dest="no_filter_padding",
741
+ action="store_true",
742
+ default=False,
743
+ help="Omits padding of optional template filters. Particularly effective when "
744
+ "the target is much larger than the template. However, for fast osciliating "
745
+ "filters setting this flag can introduce aliasing effects.",
746
+ )
724
747
  performance_group.add_argument(
725
748
  "--interpolation_order",
726
749
  dest="interpolation_order",
727
750
  required=False,
728
751
  type=int,
729
752
  default=3,
730
- help="Spline interpolation used for template rotations. If less than zero "
731
- "no interpolation is performed.",
753
+ help="Spline interpolation used for rotations.",
732
754
  )
733
755
  performance_group.add_argument(
734
756
  "--use_mixed_precision",
@@ -755,7 +777,20 @@ def parse_args():
755
777
  default=0,
756
778
  help="Minimum template matching scores to consider for analysis.",
757
779
  )
758
-
780
+ analyzer_group.add_argument(
781
+ "-p",
782
+ dest="peak_calling",
783
+ action="store_true",
784
+ default=False,
785
+ help="Perform peak calling instead of score aggregation.",
786
+ )
787
+ analyzer_group.add_argument(
788
+ "--number_of_peaks",
789
+ dest="number_of_peaks",
790
+ action="store_true",
791
+ default=1000,
792
+ help="Number of peaks to call, 1000 by default.",
793
+ )
759
794
  args = parser.parse_args()
760
795
  args.version = __version__
761
796
 
@@ -924,44 +959,56 @@ def main():
924
959
  template.data, noise_proportion=1.0, normalize_power=True
925
960
  )
926
961
 
927
- available_memory = backend.get_available_memory()
962
+ # Determine suitable backend for the selected operation
963
+ available_backends = be.available_backends()
964
+ if args.backend is not None:
965
+ req_backend = args.backend
966
+ if req_backend not in available_backends:
967
+ raise ValueError("Requested backend is not available.")
968
+ available_backends = [req_backend]
969
+
970
+ be_selection = ("numpyfftw", "pytorch", "jax", "mlx")
928
971
  if args.use_gpu:
929
972
  args.cores = len(args.gpu_indices)
930
- has_torch = importlib.util.find_spec("torch") is not None
931
- has_cupy = importlib.util.find_spec("cupy") is not None
932
-
933
- if not has_torch and not has_cupy:
934
- raise ValueError(
935
- "Found neither CuPy nor PyTorch installation. You need to install"
936
- " either to enable GPU support."
937
- )
973
+ be_selection = ("pytorch", "cupy", "jax")
974
+ if args.use_mixed_precision:
975
+ be_selection = tuple(x for x in be_selection if x in ("cupy", "numpyfftw"))
938
976
 
939
- if args.peak_calling:
940
- preferred_backend = "pytorch"
941
- if not has_torch:
942
- preferred_backend = "cupy"
943
- backend.change_backend(backend_name=preferred_backend, device="cuda")
944
- else:
945
- preferred_backend = "cupy"
946
- if not has_cupy:
947
- preferred_backend = "pytorch"
948
- backend.change_backend(backend_name=preferred_backend, device="cuda")
949
- if args.use_mixed_precision and preferred_backend == "pytorch":
977
+ available_backends = [x for x in available_backends if x in be_selection]
978
+ if args.peak_calling:
979
+ if "jax" in available_backends:
980
+ available_backends.remove("jax")
981
+ if args.use_gpu and "pytorch" in available_backends:
982
+ available_backends = ("pytorch",)
983
+ if args.interpolation_order == 3:
950
984
  raise NotImplementedError(
951
- "pytorch backend does not yet support mixed precision."
952
- " Consider installing CuPy to enable this feature."
985
+ "Pytorch does not support --interpolation_order 3, 1 is supported."
953
986
  )
954
- elif args.use_mixed_precision:
955
- backend.change_backend(
956
- backend_name="cupy",
957
- default_dtype=backend._array_backend.float16,
958
- complex_dtype=backend._array_backend.complex64,
959
- default_dtype_int=backend._array_backend.int16,
960
- )
961
- available_memory = backend.get_available_memory() * args.cores
962
- if preferred_backend == "pytorch" and args.interpolation_order == 3:
963
- args.interpolation_order = 1
987
+ # dim_match = len(template.shape) == len(target.shape) <= 3
988
+ # if dim_match and args.use_gpu and "jax" in available_backends:
989
+ # args.interpolation_order = 1
990
+ # available_backends = ["jax"]
964
991
 
992
+ backend_preference = ("numpyfftw", "pytorch", "jax", "mlx")
993
+ if args.use_gpu:
994
+ backend_preference = ("cupy", "pytorch", "jax")
995
+ for pref in backend_preference:
996
+ if pref not in available_backends:
997
+ continue
998
+ be.change_backend(pref)
999
+ if pref == "pytorch":
1000
+ be.change_backend(pref, device="cuda" if args.use_gpu else "cpu")
1001
+
1002
+ if args.use_mixed_precision:
1003
+ be.change_backend(
1004
+ backend_name=pref,
1005
+ default_dtype=be._array_backend.float16,
1006
+ complex_dtype=be._array_backend.complex64,
1007
+ default_dtype_int=be._array_backend.int16,
1008
+ )
1009
+ break
1010
+
1011
+ available_memory = be.get_available_memory() * be.device_count()
965
1012
  if args.memory is None:
966
1013
  args.memory = int(args.memory_scaling * available_memory)
967
1014
 
@@ -978,17 +1025,15 @@ def main():
978
1025
  rotations=parse_rotation_logic(args=args, ndim=template.data.ndim),
979
1026
  )
980
1027
 
981
- template_filter, target_filter = setup_filter(args, template, target)
982
- matching_data.template_filter = template_filter
983
- matching_data.target_filter = target_filter
1028
+ matching_data.template_filter, matching_data.target_filter = setup_filter(
1029
+ args, template, target
1030
+ )
984
1031
 
985
1032
  template_box = matching_data._output_template_shape
986
1033
  if not args.pad_fourier:
987
- template_box = np.ones(len(template_box), dtype=int)
1034
+ template_box = tuple(0 for _ in range(len(template_box)))
988
1035
 
989
- target_padding = np.zeros(
990
- (backend.size(matching_data._output_template_shape)), dtype=int
991
- )
1036
+ target_padding = tuple(0 for _ in range(len(template_box)))
992
1037
  if args.pad_target_edges:
993
1038
  target_padding = matching_data._output_template_shape
994
1039
 
@@ -1001,10 +1046,10 @@ def main():
1001
1046
  split_only_outer=args.use_gpu,
1002
1047
  matching_method=args.score,
1003
1048
  analyzer_method=callback_class.__name__,
1004
- backend=backend._backend_name,
1005
- float_nbytes=backend.datatype_bytes(backend._float_dtype),
1006
- complex_nbytes=backend.datatype_bytes(backend._complex_dtype),
1007
- integer_nbytes=backend.datatype_bytes(backend._int_dtype),
1049
+ backend=be._backend_name,
1050
+ float_nbytes=be.datatype_bytes(be._float_dtype),
1051
+ complex_nbytes=be.datatype_bytes(be._complex_dtype),
1052
+ integer_nbytes=be.datatype_bytes(be._int_dtype),
1008
1053
  )
1009
1054
 
1010
1055
  if splits is None:
@@ -1021,27 +1066,36 @@ def main():
1021
1066
  )
1022
1067
  gpus_used = 0 if args.gpu_indices is None else len(args.gpu_indices)
1023
1068
  options = {
1024
- "CPU Cores": args.cores,
1025
- "Run on GPU": f"{args.use_gpu} [N={gpus_used}]",
1026
- "Use Mixed Precision": args.use_mixed_precision,
1027
- "Assigned Memory [MB]": f"{args.memory // 1e6} [out of {available_memory//1e6}]",
1028
- "Temporary Directory": args.temp_directory,
1069
+ "Angular Sampling": f"{args.angular_sampling}"
1070
+ f" [{matching_data.rotations.shape[0]} rotations]",
1071
+ "Center Template": not args.no_centering,
1072
+ "Scramble Template": args.scramble_phases,
1073
+ "Invert Contrast": args.invert_target_contrast,
1029
1074
  "Extend Fourier Grid": not args.no_fourier_padding,
1030
1075
  "Extend Target Edges": not args.no_edge_padding,
1031
1076
  "Interpolation Order": args.interpolation_order,
1032
- "Score": f"{args.score}",
1033
1077
  "Setup Function": f"{get_func_fullname(matching_setup)}",
1034
1078
  "Scoring Function": f"{get_func_fullname(matching_score)}",
1035
- "Angular Sampling": f"{args.angular_sampling}"
1036
- f" [{matching_data.rotations.shape[0]} rotations]",
1037
- "Scramble Template": args.scramble_phases,
1038
- "Target Splits": f"{target_split} [N={n_splits}]",
1039
1079
  }
1040
1080
 
1041
1081
  print_block(
1042
- name="Template Matching Options",
1082
+ name="Template Matching",
1043
1083
  data=options,
1044
- label_width=max(len(key) for key in options.keys()) + 2,
1084
+ label_width=max(len(key) for key in options.keys()) + 3,
1085
+ )
1086
+
1087
+ compute_options = {
1088
+ "Backend": be._BACKEND_REGISTRY[be._backend_name],
1089
+ "Compute Devices": f"CPU [{args.cores}], GPU [{gpus_used}]",
1090
+ "Use Mixed Precision": args.use_mixed_precision,
1091
+ "Assigned Memory [MB]": f"{args.memory // 1e6} [out of {available_memory//1e6}]",
1092
+ "Temporary Directory": args.temp_directory,
1093
+ "Target Splits": f"{target_split} [N={n_splits}]",
1094
+ }
1095
+ print_block(
1096
+ name="Computation",
1097
+ data=compute_options,
1098
+ label_width=max(len(key) for key in options.keys()) + 3,
1045
1099
  )
1046
1100
 
1047
1101
  filter_args = {
@@ -1059,7 +1113,7 @@ def main():
1059
1113
  filter_args["CTF File"] = args.ctf_file
1060
1114
  filter_args["Defocus"] = args.defocus
1061
1115
  filter_args["Phase Shift"] = args.phase_shift
1062
- filter_args["No Flip Phase"] = args.no_flip_phase
1116
+ filter_args["Flip Phase"] = args.no_flip_phase
1063
1117
  filter_args["Acceleration Voltage"] = args.acceleration_voltage
1064
1118
  filter_args["Spherical Aberration"] = args.spherical_aberration
1065
1119
  filter_args["Amplitude Contrast"] = args.amplitude_contrast
@@ -1070,20 +1124,20 @@ def main():
1070
1124
  print_block(
1071
1125
  name="Filters",
1072
1126
  data=filter_args,
1073
- label_width=max(len(key) for key in options.keys()) + 2,
1127
+ label_width=max(len(key) for key in options.keys()) + 3,
1074
1128
  )
1075
1129
 
1076
1130
  analyzer_args = {
1077
1131
  "score_threshold": args.score_threshold,
1078
- "number_of_peaks": 1000,
1079
- "convolution_mode": "valid",
1132
+ "number_of_peaks": args.number_of_peaks,
1133
+ "min_distance": max(template.shape) // 2,
1134
+ "min_boundary_distance": max(template.shape) // 2,
1080
1135
  "use_memmap": args.use_memmap,
1081
1136
  }
1082
- analyzer_args = {"Analyzer": callback_class, **analyzer_args}
1083
1137
  print_block(
1084
- name="Score Analysis Options",
1085
- data=analyzer_args,
1086
- label_width=max(len(key) for key in options.keys()) + 2,
1138
+ name="Analyzer",
1139
+ data={"Analyzer": callback_class, **analyzer_args},
1140
+ label_width=max(len(key) for key in options.keys()) + 3,
1087
1141
  )
1088
1142
  print("\n" + "-" * 80)
1089
1143
 
@@ -1104,6 +1158,7 @@ def main():
1104
1158
  target_splits=splits,
1105
1159
  pad_target_edges=args.pad_target_edges,
1106
1160
  pad_fourier=args.pad_fourier,
1161
+ pad_template_filter=not args.no_filter_padding,
1107
1162
  interpolation_order=args.interpolation_order,
1108
1163
  )
1109
1164
 
@@ -1113,19 +1168,18 @@ def main():
1113
1168
  candidates[0] *= target_mask.data
1114
1169
  with warnings.catch_warnings():
1115
1170
  warnings.simplefilter("ignore", category=UserWarning)
1116
- nbytes = backend.datatype_bytes(backend._float_dtype)
1171
+ nbytes = be.datatype_bytes(be._float_dtype)
1117
1172
  dtype = np.float32 if nbytes == 4 else np.float16
1118
1173
  rot_dim = matching_data.rotations.shape[1]
1119
1174
  candidates[3] = {
1120
- x: euler_from_rotationmatrix(
1121
- np.frombuffer(i, dtype=dtype).reshape(rot_dim, rot_dim)
1122
- )
1175
+ x: np.frombuffer(i, dtype=dtype).reshape(rot_dim, rot_dim)
1123
1176
  for i, x in candidates[3].items()
1124
1177
  }
1125
1178
  candidates.append((target.origin, template.origin, template.sampling_rate, args))
1126
1179
  write_pickle(data=candidates, filename=args.output)
1127
1180
 
1128
1181
  runtime = time() - start
1182
+ print("\n" + "-" * 80)
1129
1183
  print(f"\nRuntime real: {runtime:.3f}s user: {(runtime * args.cores):.3f}s.")
1130
1184
 
1131
1185