pytme 0.2.1__cp311-cp311-macosx_14_0_arm64.whl → 0.2.2__cp311-cp311-macosx_14_0_arm64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (49)
  1. {pytme-0.2.1.data → pytme-0.2.2.data}/scripts/match_template.py +147 -93
  2. {pytme-0.2.1.data → pytme-0.2.2.data}/scripts/postprocess.py +67 -26
  3. {pytme-0.2.1.data → pytme-0.2.2.data}/scripts/preprocessor_gui.py +175 -85
  4. pytme-0.2.2.dist-info/METADATA +91 -0
  5. pytme-0.2.2.dist-info/RECORD +74 -0
  6. {pytme-0.2.1.dist-info → pytme-0.2.2.dist-info}/WHEEL +1 -1
  7. scripts/extract_candidates.py +20 -13
  8. scripts/match_template.py +147 -93
  9. scripts/match_template_filters.py +154 -95
  10. scripts/postprocess.py +67 -26
  11. scripts/preprocessor_gui.py +175 -85
  12. scripts/refine_matches.py +265 -61
  13. tme/__init__.py +0 -1
  14. tme/__version__.py +1 -1
  15. tme/analyzer.py +451 -809
  16. tme/backends/__init__.py +40 -11
  17. tme/backends/_jax_utils.py +185 -0
  18. tme/backends/cupy_backend.py +111 -223
  19. tme/backends/jax_backend.py +214 -150
  20. tme/backends/matching_backend.py +445 -384
  21. tme/backends/mlx_backend.py +32 -59
  22. tme/backends/npfftw_backend.py +239 -507
  23. tme/backends/pytorch_backend.py +21 -145
  24. tme/density.py +233 -363
  25. tme/extensions.cpython-311-darwin.so +0 -0
  26. tme/matching_data.py +322 -285
  27. tme/matching_exhaustive.py +172 -1493
  28. tme/matching_optimization.py +143 -106
  29. tme/matching_scores.py +884 -0
  30. tme/matching_utils.py +280 -386
  31. tme/memory.py +377 -0
  32. tme/orientations.py +52 -12
  33. tme/parser.py +3 -4
  34. tme/preprocessing/_utils.py +61 -32
  35. tme/preprocessing/compose.py +7 -3
  36. tme/preprocessing/frequency_filters.py +49 -39
  37. tme/preprocessing/tilt_series.py +34 -40
  38. tme/preprocessor.py +560 -526
  39. tme/structure.py +491 -188
  40. tme/types.py +5 -3
  41. pytme-0.2.1.dist-info/METADATA +0 -73
  42. pytme-0.2.1.dist-info/RECORD +0 -73
  43. tme/helpers.py +0 -881
  44. tme/matching_constrained.py +0 -195
  45. {pytme-0.2.1.data → pytme-0.2.2.data}/scripts/estimate_ram_usage.py +0 -0
  46. {pytme-0.2.1.data → pytme-0.2.2.data}/scripts/preprocess.py +0 -0
  47. {pytme-0.2.1.dist-info → pytme-0.2.2.dist-info}/LICENSE +0 -0
  48. {pytme-0.2.1.dist-info → pytme-0.2.2.dist-info}/entry_points.txt +0 -0
  49. {pytme-0.2.1.dist-info → pytme-0.2.2.dist-info}/top_level.txt +0 -0
@@ -8,7 +8,6 @@
8
8
  import os
9
9
  import argparse
10
10
  import warnings
11
- import importlib.util
12
11
  from sys import exit
13
12
  from time import time
14
13
  from typing import Tuple
@@ -22,7 +21,6 @@ from tme.matching_utils import (
22
21
  get_rotation_matrices,
23
22
  get_rotations_around_vector,
24
23
  compute_parallelization_schedule,
25
- euler_from_rotationmatrix,
26
24
  scramble_phases,
27
25
  generate_tempfile_name,
28
26
  write_pickle,
@@ -33,7 +31,7 @@ from tme.analyzer import (
33
31
  MaxScoreOverRotations,
34
32
  PeakCallerMaximumFilter,
35
33
  )
36
- from tme.backends import backend
34
+ from tme.backends import backend as be
37
35
  from tme.preprocessing import Compose
38
36
 
39
37
 
@@ -52,7 +50,7 @@ def print_block(name: str, data: dict, label_width=20) -> None:
52
50
 
53
51
  def print_entry() -> None:
54
52
  width = 80
55
- text = f" pyTME v{__version__} "
53
+ text = f" pytme v{__version__} "
56
54
  padding_total = width - len(text) - 2
57
55
  padding_left = padding_total // 2
58
56
  padding_right = padding_total - padding_left
@@ -273,7 +271,7 @@ def setup_filter(args, template: Density, target: Density) -> Tuple[Compose, Com
273
271
  return_real_fourier=True,
274
272
  )
275
273
  ctf.sampling_rate = template.sampling_rate
276
- ctf.flip_phase = not args.no_flip_phase
274
+ ctf.flip_phase = args.no_flip_phase
277
275
  ctf.amplitude_contrast = args.amplitude_contrast
278
276
  ctf.spherical_aberration = args.spherical_aberration
279
277
  ctf.acceleration_voltage = args.acceleration_voltage * 1e3
@@ -306,6 +304,12 @@ def setup_filter(args, template: Density, target: Density) -> Tuple[Compose, Com
306
304
  if highpass is not None:
307
305
  highpass = np.max(np.divide(template.sampling_rate, highpass))
308
306
 
307
+ try:
308
+ if args.lowpass >= args.highpass:
309
+ warnings.warn("--lowpass should be smaller than --highpass.")
310
+ except Exception:
311
+ pass
312
+
309
313
  bandpass = BandPassFilter(
310
314
  use_gaussian=args.no_pass_smooth,
311
315
  lowpass=lowpass,
@@ -313,7 +317,9 @@ def setup_filter(args, template: Density, target: Density) -> Tuple[Compose, Com
313
317
  sampling_rate=template.sampling_rate,
314
318
  )
315
319
  template_filter.append(bandpass)
316
- target_filter.append(bandpass)
320
+
321
+ if not args.no_filter_target:
322
+ target_filter.append(bandpass)
317
323
 
318
324
  if args.whiten_spectrum:
319
325
  whitening_filter = LinearWhiteningFilter()
@@ -335,7 +341,10 @@ def setup_filter(args, template: Density, target: Density) -> Tuple[Compose, Com
335
341
 
336
342
 
337
343
  def parse_args():
338
- parser = argparse.ArgumentParser(description="Perform template matching.")
344
+ parser = argparse.ArgumentParser(
345
+ description="Perform template matching.",
346
+ formatter_class=argparse.ArgumentDefaultsHelpFormatter,
347
+ )
339
348
 
340
349
  io_group = parser.add_argument_group("Input / Output")
341
350
  io_group.add_argument(
@@ -405,13 +414,6 @@ def parse_args():
405
414
  choices=list(MATCHING_EXHAUSTIVE_REGISTER.keys()),
406
415
  help="Template matching scoring function.",
407
416
  )
408
- scoring_group.add_argument(
409
- "-p",
410
- dest="peak_calling",
411
- action="store_true",
412
- default=False,
413
- help="Perform peak calling instead of score aggregation.",
414
- )
415
417
 
416
418
  angular_group = parser.add_argument_group("Angular Sampling")
417
419
  angular_exclusive = angular_group.add_mutually_exclusive_group(required=True)
@@ -445,7 +447,7 @@ def parse_args():
445
447
  type=check_positive,
446
448
  default=360.0,
447
449
  required=False,
448
- help="Sampling angle along the z-axis of the cone. Defaults to 360.",
450
+ help="Sampling angle along the z-axis of the cone.",
449
451
  )
450
452
  angular_group.add_argument(
451
453
  "--axis_sampling",
@@ -513,8 +515,7 @@ def parse_args():
513
515
  required=False,
514
516
  type=float,
515
517
  default=0.85,
516
- help="Fraction of available memory that can be used. Defaults to 0.85 and is "
517
- "ignored if --ram is set",
518
+ help="Fraction of available memory to be used. Ignored if --ram is set.",
518
519
  )
519
520
  computation_group.add_argument(
520
521
  "--temp_directory",
@@ -522,7 +523,13 @@ def parse_args():
522
523
  default=None,
523
524
  help="Directory for temporary objects. Faster I/O improves runtime.",
524
525
  )
525
-
526
+ computation_group.add_argument(
527
+ "--backend",
528
+ dest="backend",
529
+ default=None,
530
+ choices=be.available_backends(),
531
+ help="[Expert] Overwrite default computation backend.",
532
+ )
526
533
  filter_group = parser.add_argument_group("Filters")
527
534
  filter_group.add_argument(
528
535
  "--lowpass",
@@ -552,9 +559,9 @@ def parse_args():
552
559
  dest="pass_format",
553
560
  type=str,
554
561
  required=False,
562
+ default="sampling_rate",
555
563
  choices=["sampling_rate", "voxel", "frequency"],
556
- help="How values passed to --lowpass and --highpass should be interpreted. "
557
- "By default, they are assumed to be in units of sampling rate, e.g. Ångstrom.",
564
+ help="How values passed to --lowpass and --highpass should be interpreted. ",
558
565
  )
559
566
  filter_group.add_argument(
560
567
  "--whiten_spectrum",
@@ -613,6 +620,13 @@ def parse_args():
613
620
  required=False,
614
621
  help="Analogous to --interpolation_order but for reconstruction.",
615
622
  )
623
+ filter_group.add_argument(
624
+ "--no_filter_target",
625
+ dest="no_filter_target",
626
+ action="store_true",
627
+ default=False,
628
+ help="Whether to not apply potential filters to the target.",
629
+ )
616
630
 
617
631
  ctf_group = parser.add_argument_group("Contrast Transfer Function")
618
632
  ctf_group.add_argument(
@@ -647,7 +661,7 @@ def parse_args():
647
661
  type=float,
648
662
  required=False,
649
663
  default=300,
650
- help="Acceleration voltage in kV, defaults to 300.",
664
+ help="Acceleration voltage in kV.",
651
665
  )
652
666
  ctf_group.add_argument(
653
667
  "--spherical_aberration",
@@ -663,14 +677,14 @@ def parse_args():
663
677
  type=float,
664
678
  required=False,
665
679
  default=0.07,
666
- help="Amplitude contrast, defaults to 0.07.",
680
+ help="Amplitude contrast.",
667
681
  )
668
682
  ctf_group.add_argument(
669
683
  "--no_flip_phase",
670
684
  dest="no_flip_phase",
671
685
  action="store_false",
672
686
  required=False,
673
- help="Whether the phase of the computed CTF should not be flipped.",
687
+ help="Do not perform phase-flipping CTF correction.",
674
688
  )
675
689
  ctf_group.add_argument(
676
690
  "--correct_defocus_gradient",
@@ -721,14 +735,22 @@ def parse_args():
721
735
  "for numerical stability. When working with very large targets, e.g. tomograms, "
722
736
  "it is safe to use this flag and benefit from the performance gain.",
723
737
  )
738
+ performance_group.add_argument(
739
+ "--no_filter_padding",
740
+ dest="no_filter_padding",
741
+ action="store_true",
742
+ default=False,
743
+ help="Omits padding of optional template filters. Particularly effective when "
744
+ "the target is much larger than the template. However, for fast osciliating "
745
+ "filters setting this flag can introduce aliasing effects.",
746
+ )
724
747
  performance_group.add_argument(
725
748
  "--interpolation_order",
726
749
  dest="interpolation_order",
727
750
  required=False,
728
751
  type=int,
729
752
  default=3,
730
- help="Spline interpolation used for template rotations. If less than zero "
731
- "no interpolation is performed.",
753
+ help="Spline interpolation used for rotations.",
732
754
  )
733
755
  performance_group.add_argument(
734
756
  "--use_mixed_precision",
@@ -755,7 +777,20 @@ def parse_args():
755
777
  default=0,
756
778
  help="Minimum template matching scores to consider for analysis.",
757
779
  )
758
-
780
+ analyzer_group.add_argument(
781
+ "-p",
782
+ dest="peak_calling",
783
+ action="store_true",
784
+ default=False,
785
+ help="Perform peak calling instead of score aggregation.",
786
+ )
787
+ analyzer_group.add_argument(
788
+ "--number_of_peaks",
789
+ dest="number_of_peaks",
790
+ action="store_true",
791
+ default=1000,
792
+ help="Number of peaks to call, 1000 by default.",
793
+ )
759
794
  args = parser.parse_args()
760
795
  args.version = __version__
761
796
 
@@ -924,44 +959,56 @@ def main():
924
959
  template.data, noise_proportion=1.0, normalize_power=True
925
960
  )
926
961
 
927
- available_memory = backend.get_available_memory()
962
+ # Determine suitable backend for the selected operation
963
+ available_backends = be.available_backends()
964
+ if args.backend is not None:
965
+ req_backend = args.backend
966
+ if req_backend not in available_backends:
967
+ raise ValueError("Requested backend is not available.")
968
+ available_backends = [req_backend]
969
+
970
+ be_selection = ("numpyfftw", "pytorch", "jax", "mlx")
928
971
  if args.use_gpu:
929
972
  args.cores = len(args.gpu_indices)
930
- has_torch = importlib.util.find_spec("torch") is not None
931
- has_cupy = importlib.util.find_spec("cupy") is not None
932
-
933
- if not has_torch and not has_cupy:
934
- raise ValueError(
935
- "Found neither CuPy nor PyTorch installation. You need to install"
936
- " either to enable GPU support."
937
- )
973
+ be_selection = ("pytorch", "cupy", "jax")
974
+ if args.use_mixed_precision:
975
+ be_selection = tuple(x for x in be_selection if x in ("cupy", "numpyfftw"))
938
976
 
939
- if args.peak_calling:
940
- preferred_backend = "pytorch"
941
- if not has_torch:
942
- preferred_backend = "cupy"
943
- backend.change_backend(backend_name=preferred_backend, device="cuda")
944
- else:
945
- preferred_backend = "cupy"
946
- if not has_cupy:
947
- preferred_backend = "pytorch"
948
- backend.change_backend(backend_name=preferred_backend, device="cuda")
949
- if args.use_mixed_precision and preferred_backend == "pytorch":
977
+ available_backends = [x for x in available_backends if x in be_selection]
978
+ if args.peak_calling:
979
+ if "jax" in available_backends:
980
+ available_backends.remove("jax")
981
+ if args.use_gpu and "pytorch" in available_backends:
982
+ available_backends = ("pytorch",)
983
+ if args.interpolation_order == 3:
950
984
  raise NotImplementedError(
951
- "pytorch backend does not yet support mixed precision."
952
- " Consider installing CuPy to enable this feature."
985
+ "Pytorch does not support --interpolation_order 3, 1 is supported."
953
986
  )
954
- elif args.use_mixed_precision:
955
- backend.change_backend(
956
- backend_name="cupy",
957
- default_dtype=backend._array_backend.float16,
958
- complex_dtype=backend._array_backend.complex64,
959
- default_dtype_int=backend._array_backend.int16,
960
- )
961
- available_memory = backend.get_available_memory() * args.cores
962
- if preferred_backend == "pytorch" and args.interpolation_order == 3:
963
- args.interpolation_order = 1
987
+ # dim_match = len(template.shape) == len(target.shape) <= 3
988
+ # if dim_match and args.use_gpu and "jax" in available_backends:
989
+ # args.interpolation_order = 1
990
+ # available_backends = ["jax"]
964
991
 
992
+ backend_preference = ("numpyfftw", "pytorch", "jax", "mlx")
993
+ if args.use_gpu:
994
+ backend_preference = ("cupy", "pytorch", "jax")
995
+ for pref in backend_preference:
996
+ if pref not in available_backends:
997
+ continue
998
+ be.change_backend(pref)
999
+ if pref == "pytorch":
1000
+ be.change_backend(pref, device="cuda" if args.use_gpu else "cpu")
1001
+
1002
+ if args.use_mixed_precision:
1003
+ be.change_backend(
1004
+ backend_name=pref,
1005
+ default_dtype=be._array_backend.float16,
1006
+ complex_dtype=be._array_backend.complex64,
1007
+ default_dtype_int=be._array_backend.int16,
1008
+ )
1009
+ break
1010
+
1011
+ available_memory = be.get_available_memory() * be.device_count()
965
1012
  if args.memory is None:
966
1013
  args.memory = int(args.memory_scaling * available_memory)
967
1014
 
@@ -978,17 +1025,15 @@ def main():
978
1025
  rotations=parse_rotation_logic(args=args, ndim=template.data.ndim),
979
1026
  )
980
1027
 
981
- template_filter, target_filter = setup_filter(args, template, target)
982
- matching_data.template_filter = template_filter
983
- matching_data.target_filter = target_filter
1028
+ matching_data.template_filter, matching_data.target_filter = setup_filter(
1029
+ args, template, target
1030
+ )
984
1031
 
985
1032
  template_box = matching_data._output_template_shape
986
1033
  if not args.pad_fourier:
987
- template_box = np.ones(len(template_box), dtype=int)
1034
+ template_box = tuple(0 for _ in range(len(template_box)))
988
1035
 
989
- target_padding = np.zeros(
990
- (backend.size(matching_data._output_template_shape)), dtype=int
991
- )
1036
+ target_padding = tuple(0 for _ in range(len(template_box)))
992
1037
  if args.pad_target_edges:
993
1038
  target_padding = matching_data._output_template_shape
994
1039
 
@@ -1001,10 +1046,10 @@ def main():
1001
1046
  split_only_outer=args.use_gpu,
1002
1047
  matching_method=args.score,
1003
1048
  analyzer_method=callback_class.__name__,
1004
- backend=backend._backend_name,
1005
- float_nbytes=backend.datatype_bytes(backend._float_dtype),
1006
- complex_nbytes=backend.datatype_bytes(backend._complex_dtype),
1007
- integer_nbytes=backend.datatype_bytes(backend._int_dtype),
1049
+ backend=be._backend_name,
1050
+ float_nbytes=be.datatype_bytes(be._float_dtype),
1051
+ complex_nbytes=be.datatype_bytes(be._complex_dtype),
1052
+ integer_nbytes=be.datatype_bytes(be._int_dtype),
1008
1053
  )
1009
1054
 
1010
1055
  if splits is None:
@@ -1021,27 +1066,36 @@ def main():
1021
1066
  )
1022
1067
  gpus_used = 0 if args.gpu_indices is None else len(args.gpu_indices)
1023
1068
  options = {
1024
- "CPU Cores": args.cores,
1025
- "Run on GPU": f"{args.use_gpu} [N={gpus_used}]",
1026
- "Use Mixed Precision": args.use_mixed_precision,
1027
- "Assigned Memory [MB]": f"{args.memory // 1e6} [out of {available_memory//1e6}]",
1028
- "Temporary Directory": args.temp_directory,
1069
+ "Angular Sampling": f"{args.angular_sampling}"
1070
+ f" [{matching_data.rotations.shape[0]} rotations]",
1071
+ "Center Template": not args.no_centering,
1072
+ "Scramble Template": args.scramble_phases,
1073
+ "Invert Contrast": args.invert_target_contrast,
1029
1074
  "Extend Fourier Grid": not args.no_fourier_padding,
1030
1075
  "Extend Target Edges": not args.no_edge_padding,
1031
1076
  "Interpolation Order": args.interpolation_order,
1032
- "Score": f"{args.score}",
1033
1077
  "Setup Function": f"{get_func_fullname(matching_setup)}",
1034
1078
  "Scoring Function": f"{get_func_fullname(matching_score)}",
1035
- "Angular Sampling": f"{args.angular_sampling}"
1036
- f" [{matching_data.rotations.shape[0]} rotations]",
1037
- "Scramble Template": args.scramble_phases,
1038
- "Target Splits": f"{target_split} [N={n_splits}]",
1039
1079
  }
1040
1080
 
1041
1081
  print_block(
1042
- name="Template Matching Options",
1082
+ name="Template Matching",
1043
1083
  data=options,
1044
- label_width=max(len(key) for key in options.keys()) + 2,
1084
+ label_width=max(len(key) for key in options.keys()) + 3,
1085
+ )
1086
+
1087
+ compute_options = {
1088
+ "Backend": be._BACKEND_REGISTRY[be._backend_name],
1089
+ "Compute Devices": f"CPU [{args.cores}], GPU [{gpus_used}]",
1090
+ "Use Mixed Precision": args.use_mixed_precision,
1091
+ "Assigned Memory [MB]": f"{args.memory // 1e6} [out of {available_memory//1e6}]",
1092
+ "Temporary Directory": args.temp_directory,
1093
+ "Target Splits": f"{target_split} [N={n_splits}]",
1094
+ }
1095
+ print_block(
1096
+ name="Computation",
1097
+ data=compute_options,
1098
+ label_width=max(len(key) for key in options.keys()) + 3,
1045
1099
  )
1046
1100
 
1047
1101
  filter_args = {
@@ -1059,7 +1113,7 @@ def main():
1059
1113
  filter_args["CTF File"] = args.ctf_file
1060
1114
  filter_args["Defocus"] = args.defocus
1061
1115
  filter_args["Phase Shift"] = args.phase_shift
1062
- filter_args["No Flip Phase"] = args.no_flip_phase
1116
+ filter_args["Flip Phase"] = args.no_flip_phase
1063
1117
  filter_args["Acceleration Voltage"] = args.acceleration_voltage
1064
1118
  filter_args["Spherical Aberration"] = args.spherical_aberration
1065
1119
  filter_args["Amplitude Contrast"] = args.amplitude_contrast
@@ -1070,20 +1124,20 @@ def main():
1070
1124
  print_block(
1071
1125
  name="Filters",
1072
1126
  data=filter_args,
1073
- label_width=max(len(key) for key in options.keys()) + 2,
1127
+ label_width=max(len(key) for key in options.keys()) + 3,
1074
1128
  )
1075
1129
 
1076
1130
  analyzer_args = {
1077
1131
  "score_threshold": args.score_threshold,
1078
- "number_of_peaks": 1000,
1079
- "convolution_mode": "valid",
1132
+ "number_of_peaks": args.number_of_peaks,
1133
+ "min_distance": max(template.shape) // 2,
1134
+ "min_boundary_distance": max(template.shape) // 2,
1080
1135
  "use_memmap": args.use_memmap,
1081
1136
  }
1082
- analyzer_args = {"Analyzer": callback_class, **analyzer_args}
1083
1137
  print_block(
1084
- name="Score Analysis Options",
1085
- data=analyzer_args,
1086
- label_width=max(len(key) for key in options.keys()) + 2,
1138
+ name="Analyzer",
1139
+ data={"Analyzer": callback_class, **analyzer_args},
1140
+ label_width=max(len(key) for key in options.keys()) + 3,
1087
1141
  )
1088
1142
  print("\n" + "-" * 80)
1089
1143
 
@@ -1104,6 +1158,7 @@ def main():
1104
1158
  target_splits=splits,
1105
1159
  pad_target_edges=args.pad_target_edges,
1106
1160
  pad_fourier=args.pad_fourier,
1161
+ pad_template_filter=not args.no_filter_padding,
1107
1162
  interpolation_order=args.interpolation_order,
1108
1163
  )
1109
1164
 
@@ -1113,19 +1168,18 @@ def main():
1113
1168
  candidates[0] *= target_mask.data
1114
1169
  with warnings.catch_warnings():
1115
1170
  warnings.simplefilter("ignore", category=UserWarning)
1116
- nbytes = backend.datatype_bytes(backend._float_dtype)
1171
+ nbytes = be.datatype_bytes(be._float_dtype)
1117
1172
  dtype = np.float32 if nbytes == 4 else np.float16
1118
1173
  rot_dim = matching_data.rotations.shape[1]
1119
1174
  candidates[3] = {
1120
- x: euler_from_rotationmatrix(
1121
- np.frombuffer(i, dtype=dtype).reshape(rot_dim, rot_dim)
1122
- )
1175
+ x: np.frombuffer(i, dtype=dtype).reshape(rot_dim, rot_dim)
1123
1176
  for i, x in candidates[3].items()
1124
1177
  }
1125
1178
  candidates.append((target.origin, template.origin, template.sampling_rate, args))
1126
1179
  write_pickle(data=candidates, filename=args.output)
1127
1180
 
1128
1181
  runtime = time() - start
1182
+ print("\n" + "-" * 80)
1129
1183
  print(f"\nRuntime real: {runtime:.3f}s user: {(runtime * args.cores):.3f}s.")
1130
1184
 
1131
1185
 
@@ -8,9 +8,8 @@
8
8
  import argparse
9
9
  from sys import exit
10
10
  from os import getcwd
11
- from os.path import join, abspath
12
11
  from typing import List, Tuple
13
- from os.path import splitext
12
+ from os.path import join, abspath, splitext
14
13
 
15
14
  import numpy as np
16
15
  from numpy.typing import NDArray
@@ -26,6 +25,7 @@ from tme.analyzer import (
26
25
  )
27
26
  from tme.matching_utils import (
28
27
  load_pickle,
28
+ centered_mask,
29
29
  euler_to_rotationmatrix,
30
30
  euler_from_rotationmatrix,
31
31
  )
@@ -41,9 +41,7 @@ PEAK_CALLERS = {
41
41
 
42
42
 
43
43
  def parse_args():
44
- parser = argparse.ArgumentParser(
45
- description="Peak Calling for Template Matching Outputs"
46
- )
44
+ parser = argparse.ArgumentParser(description="Analyze Template Matching Outputs")
47
45
 
48
46
  input_group = parser.add_argument_group("Input")
49
47
  output_group = parser.add_argument_group("Output")
@@ -56,6 +54,13 @@ def parse_args():
56
54
  nargs="+",
57
55
  help="Path to the output of match_template.py.",
58
56
  )
57
+ input_group.add_argument(
58
+ "--background_file",
59
+ required=False,
60
+ nargs="+",
61
+ help="Path to an output of match_template.py used for normalization. "
62
+ "For instance from --scramble_phases or a different template.",
63
+ )
59
64
  input_group.add_argument(
60
65
  "--target_mask",
61
66
  required=False,
@@ -87,7 +92,7 @@ def parse_args():
87
92
  "average",
88
93
  ],
89
94
  default="orientations",
90
- help="Available output formats:"
95
+ help="Available output formats: "
91
96
  "orientations (translation, rotation, and score), "
92
97
  "alignment (aligned template to target based on orientations), "
93
98
  "extraction (extract regions around peaks from targets, i.e. subtomograms), "
@@ -206,6 +211,15 @@ def parse_args():
206
211
  elif args.number_of_peaks is None:
207
212
  args.number_of_peaks = 1000
208
213
 
214
+ if args.background_file is None:
215
+ args.background_file = [None]
216
+ if len(args.background_file) == 1:
217
+ args.background_file = args.background_file * len(args.input_file)
218
+ elif len(args.background_file) not in (0, len(args.input_file)):
219
+ raise ValueError(
220
+ "--background_file needs to be specified once or for each --input_file."
221
+ )
222
+
209
223
  return args
210
224
 
211
225
 
@@ -233,8 +247,8 @@ def load_template(
233
247
  return template, center, translation, template_is_density
234
248
 
235
249
 
236
- def merge_outputs(data, filepaths: List[str], args):
237
- if len(filepaths) == 0:
250
+ def merge_outputs(data, foreground_paths: List[str], background_paths: List[str], args):
251
+ if len(foreground_paths) == 0:
238
252
  return data, 1
239
253
 
240
254
  if data[0].ndim != data[2].ndim:
@@ -275,8 +289,11 @@ def merge_outputs(data, filepaths: List[str], args):
275
289
 
276
290
  entities = np.zeros_like(data[0])
277
291
  data[0] = _norm_scores(data=data, args=args)
278
- for index, filepath in enumerate(filepaths):
279
- new_scores = _norm_scores(data=load_pickle(filepath), args=args)
292
+ for index, filepath in enumerate(foreground_paths):
293
+ new_scores = _norm_scores(
294
+ data=load_match_template_output(filepath, background_paths[index]),
295
+ args=args,
296
+ )
280
297
  indices = new_scores > data[0]
281
298
  entities[indices] = index + 1
282
299
  data[0][indices] = new_scores[indices]
@@ -284,9 +301,18 @@ def merge_outputs(data, filepaths: List[str], args):
284
301
  return data, entities
285
302
 
286
303
 
304
+ def load_match_template_output(foreground_path, background_path):
305
+ data = load_pickle(foreground_path)
306
+ if background_path is not None:
307
+ data_background = load_pickle(background_path)
308
+ data[0] = (data[0] - data_background[0]) / (1 - data_background[0])
309
+ np.fmax(data[0], 0, out=data[0])
310
+ return data
311
+
312
+
287
313
  def main():
288
314
  args = parse_args()
289
- data = load_pickle(args.input_file[0])
315
+ data = load_match_template_output(args.input_file[0], args.background_file[0])
290
316
 
291
317
  target_origin, _, sampling_rate, cli_args = data[-1]
292
318
 
@@ -326,7 +352,12 @@ def main():
326
352
 
327
353
  entities = None
328
354
  if len(args.input_file) > 1:
329
- data, entities = merge_outputs(data=data, filepaths=args.input_file, args=args)
355
+ data, entities = merge_outputs(
356
+ data=data,
357
+ foreground_paths=args.input_file,
358
+ background_paths=args.background_file,
359
+ args=args,
360
+ )
330
361
 
331
362
  orientations = args.orientations
332
363
  if orientations is None:
@@ -339,24 +370,27 @@ def main():
339
370
  target_mask = Density.from_file(args.target_mask)
340
371
  scores = scores * target_mask.data
341
372
 
342
- if args.n_false_positives is not None:
343
- args.n_false_positives = max(args.n_false_positives, 1)
344
- cropped_shape = np.subtract(
345
- scores.shape, np.multiply(args.min_boundary_distance, 2)
346
- ).astype(int)
373
+ cropped_shape = np.subtract(
374
+ scores.shape, np.multiply(args.min_boundary_distance, 2)
375
+ ).astype(int)
347
376
 
348
- cropped_shape = tuple(
377
+ if args.min_boundary_distance > 0:
378
+ scores = centered_mask(scores, new_shape=cropped_shape)
379
+
380
+ if args.n_false_positives is not None:
381
+ # Rickgauer et al. 2017
382
+ cropped_slice = tuple(
349
383
  slice(
350
384
  int(args.min_boundary_distance),
351
385
  int(x - args.min_boundary_distance),
352
386
  )
353
387
  for x in scores.shape
354
388
  )
355
- # Rickgauer et al. 2017
356
- n_correlations = np.size(scores[cropped_shape]) * len(rotation_mapping)
389
+ args.n_false_positives = max(args.n_false_positives, 1)
390
+ n_correlations = np.size(scores[cropped_slice]) * len(rotation_mapping)
357
391
  minimum_score = np.multiply(
358
392
  erfcinv(2 * args.n_false_positives / n_correlations),
359
- np.sqrt(2) * np.std(scores[cropped_shape]),
393
+ np.sqrt(2) * np.std(scores[cropped_slice]),
360
394
  )
361
395
  print(f"Determined minimum score cutoff: {minimum_score}.")
362
396
  minimum_score = max(minimum_score, 0)
@@ -371,6 +405,8 @@ def main():
371
405
  "min_distance": args.min_distance,
372
406
  "min_boundary_distance": args.min_boundary_distance,
373
407
  "batch_dims": args.batch_dims,
408
+ "minimum_score": args.minimum_score,
409
+ "maximum_score": args.maximum_score,
374
410
  }
375
411
 
376
412
  peak_caller = PEAK_CALLERS[args.peak_caller](**peak_caller_kwargs)
@@ -380,7 +416,6 @@ def main():
380
416
  mask=template.data,
381
417
  rotation_mapping=rotation_mapping,
382
418
  rotation_array=rotation_array,
383
- minimum_score=args.minimum_score,
384
419
  )
385
420
  candidates = peak_caller.merge(
386
421
  candidates=[tuple(peak_caller)], **peak_caller_kwargs
@@ -388,10 +423,16 @@ def main():
388
423
  if len(candidates) == 0:
389
424
  candidates = [[], [], [], []]
390
425
  print("Found no peaks, consider changing peak calling parameters.")
391
- exit(0)
426
+ exit(-1)
392
427
 
393
428
  for translation, _, score, detail in zip(*candidates):
394
- rotations.append(rotation_mapping[rotation_array[tuple(translation)]])
429
+ rotation_index = rotation_array[tuple(translation)]
430
+ rotation = rotation_mapping.get(
431
+ rotation_index, np.zeros(template.data.ndim, int)
432
+ )
433
+ if rotation.ndim == 2:
434
+ rotation = euler_from_rotationmatrix(rotation)
435
+ rotations.append(rotation)
395
436
 
396
437
  else:
397
438
  candidates = data
@@ -430,7 +471,7 @@ def main():
430
471
  )
431
472
  exit(-1)
432
473
  orientations.translations = peak_caller.oversample_peaks(
433
- score_space=data[0],
474
+ scores=data[0],
434
475
  peak_positions=orientations.translations,
435
476
  oversampling_factor=args.peak_oversampling,
436
477
  )
@@ -570,7 +611,7 @@ def main():
570
611
  return_orientations=True,
571
612
  )
572
613
  out = np.zeros_like(template.data)
573
- out = np.zeros(np.multiply(template.shape, 2).astype(int))
614
+ # out = np.zeros(np.multiply(template.shape, 2).astype(int))
574
615
  for index in range(len(cand_slices)):
575
616
  from scipy.spatial.transform import Rotation
576
617