pytme 0.2.2__cp311-cp311-macosx_14_0_arm64.whl → 0.2.3__cp311-cp311-macosx_14_0_arm64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (33) hide show
  1. {pytme-0.2.2.data → pytme-0.2.3.data}/scripts/match_template.py +91 -142
  2. {pytme-0.2.2.data → pytme-0.2.3.data}/scripts/postprocess.py +20 -29
  3. pytme-0.2.3.data/scripts/preprocess.py +132 -0
  4. {pytme-0.2.2.data → pytme-0.2.3.data}/scripts/preprocessor_gui.py +6 -9
  5. {pytme-0.2.2.dist-info → pytme-0.2.3.dist-info}/METADATA +11 -10
  6. {pytme-0.2.2.dist-info → pytme-0.2.3.dist-info}/RECORD +33 -32
  7. pytme-0.2.2.data/scripts/preprocess.py → scripts/eval.py +1 -1
  8. scripts/match_template.py +91 -142
  9. scripts/postprocess.py +20 -29
  10. scripts/preprocess.py +95 -56
  11. scripts/preprocessor_gui.py +6 -9
  12. tme/__version__.py +1 -1
  13. tme/analyzer.py +9 -6
  14. tme/backends/__init__.py +1 -1
  15. tme/backends/_jax_utils.py +10 -8
  16. tme/backends/cupy_backend.py +2 -7
  17. tme/backends/jax_backend.py +34 -20
  18. tme/backends/npfftw_backend.py +3 -2
  19. tme/backends/pytorch_backend.py +10 -7
  20. tme/density.py +15 -8
  21. tme/extensions.cpython-311-darwin.so +0 -0
  22. tme/matching_data.py +24 -17
  23. tme/matching_exhaustive.py +36 -19
  24. tme/matching_scores.py +5 -2
  25. tme/matching_utils.py +7 -2
  26. tme/orientations.py +26 -9
  27. tme/preprocessing/composable_filter.py +7 -4
  28. tme/preprocessing/tilt_series.py +10 -32
  29. {pytme-0.2.2.data → pytme-0.2.3.data}/scripts/estimate_ram_usage.py +0 -0
  30. {pytme-0.2.2.dist-info → pytme-0.2.3.dist-info}/LICENSE +0 -0
  31. {pytme-0.2.2.dist-info → pytme-0.2.3.dist-info}/WHEEL +0 -0
  32. {pytme-0.2.2.dist-info → pytme-0.2.3.dist-info}/entry_points.txt +0 -0
  33. {pytme-0.2.2.dist-info → pytme-0.2.3.dist-info}/top_level.txt +0 -0
@@ -1,5 +1,5 @@
1
1
  #!python
2
- """ CLI interface for basic pyTME template matching functions.
2
+ """ CLI for basic pyTME template matching functions.
3
3
 
4
4
  Copyright (c) 2023 European Molecular Biology Laboratory
5
5
 
@@ -22,7 +22,6 @@ from tme.matching_utils import (
22
22
  get_rotations_around_vector,
23
23
  compute_parallelization_schedule,
24
24
  scramble_phases,
25
- generate_tempfile_name,
26
25
  write_pickle,
27
26
  )
28
27
  from tme.matching_exhaustive import scan_subsets, MATCHING_EXHAUSTIVE_REGISTER
@@ -99,7 +98,9 @@ def load_and_validate_mask(mask_target: "Density", mask_path: str, **kwargs):
99
98
  f"Expected shape of {mask_path} was {mask_target.shape},"
100
99
  f" got f{mask.shape}"
101
100
  )
102
- if not np.allclose(mask.sampling_rate, mask_target.sampling_rate):
101
+ if not np.allclose(
102
+ np.round(mask.sampling_rate, 2), np.round(mask_target.sampling_rate, 2)
103
+ ):
103
104
  raise ValueError(
104
105
  f"Expected sampling_rate of {mask_path} was {mask_target.sampling_rate}"
105
106
  f", got f{mask.sampling_rate}"
@@ -107,50 +108,6 @@ def load_and_validate_mask(mask_target: "Density", mask_path: str, **kwargs):
107
108
  return mask
108
109
 
109
110
 
110
- def crop_data(data: Density, cutoff: float, data_mask: Density = None) -> bool:
111
- """
112
- Crop the provided data and mask to a smaller box based on a cutoff value.
113
-
114
- Parameters
115
- ----------
116
- data : Density
117
- The data that should be cropped.
118
- cutoff : float
119
- The threshold value to determine which parts of the data should be kept.
120
- data_mask : Density, optional
121
- A mask for the data that should be cropped.
122
-
123
- Returns
124
- -------
125
- bool
126
- Returns True if the data was adjusted (cropped), otherwise returns False.
127
-
128
- Notes
129
- -----
130
- Cropping is performed in place.
131
- """
132
- if cutoff is None:
133
- return False
134
-
135
- box = data.trim_box(cutoff=cutoff)
136
- box_mask = box
137
- if data_mask is not None:
138
- box_mask = data_mask.trim_box(cutoff=cutoff)
139
- box = tuple(
140
- slice(min(arr.start, mask.start), max(arr.stop, mask.stop))
141
- for arr, mask in zip(box, box_mask)
142
- )
143
- if box == tuple(slice(0, x) for x in data.shape):
144
- return False
145
-
146
- data.adjust_box(box)
147
-
148
- if data_mask:
149
- data_mask.adjust_box(box)
150
-
151
- return True
152
-
153
-
154
111
  def parse_rotation_logic(args, ndim):
155
112
  if args.angular_sampling is not None:
156
113
  rotations = get_rotation_matrices(
@@ -175,8 +132,52 @@ def parse_rotation_logic(args, ndim):
175
132
  return rotations
176
133
 
177
134
 
178
- # TODO: Think about whether wedge mask should also be added to target
179
- # For now leave it at the cost of incorrect upper bound on the scores
135
+ def compute_schedule(
136
+ args,
137
+ target: Density,
138
+ matching_data: MatchingData,
139
+ callback_class,
140
+ pad_edges: bool = False,
141
+ ):
142
+ # User requested target padding
143
+ if args.pad_edges is True:
144
+ pad_edges = True
145
+ template_box = matching_data._output_template_shape
146
+ if not args.pad_fourier:
147
+ template_box = tuple(0 for _ in range(len(template_box)))
148
+
149
+ target_padding = tuple(0 for _ in range(len(template_box)))
150
+ if pad_edges:
151
+ target_padding = matching_data._output_template_shape
152
+
153
+ splits, schedule = compute_parallelization_schedule(
154
+ shape1=target.shape,
155
+ shape2=tuple(int(x) for x in template_box),
156
+ shape1_padding=tuple(int(x) for x in target_padding),
157
+ max_cores=args.cores,
158
+ max_ram=args.memory,
159
+ split_only_outer=args.use_gpu,
160
+ matching_method=args.score,
161
+ analyzer_method=callback_class.__name__,
162
+ backend=be._backend_name,
163
+ float_nbytes=be.datatype_bytes(be._float_dtype),
164
+ complex_nbytes=be.datatype_bytes(be._complex_dtype),
165
+ integer_nbytes=be.datatype_bytes(be._int_dtype),
166
+ )
167
+
168
+ if splits is None:
169
+ print(
170
+ "Found no suitable parallelization schedule. Consider increasing"
171
+ " available RAM or decreasing number of cores."
172
+ )
173
+ exit(-1)
174
+ n_splits = np.prod(list(splits.values()))
175
+ if pad_edges is False and n_splits > 1:
176
+ args.pad_edges = True
177
+ return compute_schedule(args, target, matching_data, callback_class, True)
178
+ return splits, schedule
179
+
180
+
180
181
  def setup_filter(args, template: Density, target: Density) -> Tuple[Compose, Compose]:
181
182
  from tme.preprocessing import LinearWhiteningFilter, BandPassFilter
182
183
  from tme.preprocessing.tilt_series import (
@@ -232,18 +233,23 @@ def setup_filter(args, template: Density, target: Density) -> Tuple[Compose, Com
232
233
  weight_wedge=args.tilt_weighting == "angle",
233
234
  create_continuous_wedge=create_continuous_wedge,
234
235
  )
236
+ wedge_target = WedgeReconstructed(
237
+ angles=tilt_angles,
238
+ weight_wedge=False,
239
+ create_continuous_wedge=create_continuous_wedge,
240
+ )
241
+ target_filter.append(wedge_target)
235
242
 
236
243
  wedge.opening_axis = args.wedge_axes[0]
237
244
  wedge.tilt_axis = args.wedge_axes[1]
238
245
  wedge.sampling_rate = template.sampling_rate
239
246
  template_filter.append(wedge)
240
247
  if not isinstance(wedge, WedgeReconstructed):
241
- template_filter.append(
242
- ReconstructFromTilt(
243
- reconstruction_filter=args.reconstruction_filter,
244
- interpolation_order=args.reconstruction_interpolation_order,
245
- )
248
+ reconstruction_filter = ReconstructFromTilt(
249
+ reconstruction_filter=args.reconstruction_filter,
250
+ interpolation_order=args.reconstruction_interpolation_order,
246
251
  )
252
+ template_filter.append(reconstruction_filter)
247
253
 
248
254
  if args.ctf_file is not None or args.defocus is not None:
249
255
  from tme.preprocessing.tilt_series import CTF
@@ -393,8 +399,7 @@ def parse_args():
393
399
  dest="invert_target_contrast",
394
400
  action="store_true",
395
401
  default=False,
396
- help="Invert the target's contrast and rescale linearly between zero and one. "
397
- "This option is intended for targets where templates to-be-matched have "
402
+ help="Invert the target's contrast for cases where templates to-be-matched have "
398
403
  "negative values, e.g. tomograms.",
399
404
  )
400
405
  io_group.add_argument(
@@ -696,22 +701,6 @@ def parse_args():
696
701
  )
697
702
 
698
703
  performance_group = parser.add_argument_group("Performance")
699
- performance_group.add_argument(
700
- "--cutoff_target",
701
- dest="cutoff_target",
702
- type=float,
703
- required=False,
704
- default=None,
705
- help="Target contour level (used for cropping).",
706
- )
707
- performance_group.add_argument(
708
- "--cutoff_template",
709
- dest="cutoff_template",
710
- type=float,
711
- required=False,
712
- default=None,
713
- help="Template contour level (used for cropping).",
714
- )
715
704
  performance_group.add_argument(
716
705
  "--no_centering",
717
706
  dest="no_centering",
@@ -719,30 +708,28 @@ def parse_args():
719
708
  help="Assumes the template is already centered and omits centering.",
720
709
  )
721
710
  performance_group.add_argument(
722
- "--no_edge_padding",
723
- dest="no_edge_padding",
711
+ "--pad_edges",
712
+ dest="pad_edges",
724
713
  action="store_true",
725
714
  default=False,
726
- help="Whether to not pad the edges of the target. Can be set if the target"
727
- " has a well defined bounding box, e.g. a masked reconstruction.",
715
+ help="Whether to pad the edges of the target. Useful if the target does not "
716
+ "a well-defined bounding box. Defaults to True if splitting is required.",
728
717
  )
729
718
  performance_group.add_argument(
730
- "--no_fourier_padding",
731
- dest="no_fourier_padding",
719
+ "--pad_fourier",
720
+ dest="pad_fourier",
732
721
  action="store_true",
733
722
  default=False,
734
723
  help="Whether input arrays should not be zero-padded to full convolution shape "
735
- "for numerical stability. When working with very large targets, e.g. tomograms, "
736
- "it is safe to use this flag and benefit from the performance gain.",
724
+ "for numerical stability. Typically only useful when working with small data.",
737
725
  )
738
726
  performance_group.add_argument(
739
- "--no_filter_padding",
740
- dest="no_filter_padding",
727
+ "--pad_filter",
728
+ dest="pad_filter",
741
729
  action="store_true",
742
730
  default=False,
743
- help="Omits padding of optional template filters. Particularly effective when "
744
- "the target is much larger than the template. However, for fast osciliating "
745
- "filters setting this flag can introduce aliasing effects.",
731
+ help="Pads the filter to the shape of the target. Particularly useful for fast "
732
+ "oscilating filters to avoid aliasing effects.",
746
733
  )
747
734
  performance_group.add_argument(
748
735
  "--interpolation_order",
@@ -805,9 +792,6 @@ def parse_args():
805
792
 
806
793
  os.environ["TMPDIR"] = args.temp_directory
807
794
 
808
- args.pad_target_edges = not args.no_edge_padding
809
- args.pad_fourier = not args.no_fourier_padding
810
-
811
795
  if args.score not in MATCHING_EXHAUSTIVE_REGISTER:
812
796
  raise ValueError(
813
797
  f"score has to be one of {', '.join(MATCHING_EXHAUSTIVE_REGISTER.keys())}"
@@ -862,7 +846,9 @@ def main():
862
846
  )
863
847
 
864
848
  if target.sampling_rate.size == template.sampling_rate.size:
865
- if not np.allclose(target.sampling_rate, template.sampling_rate):
849
+ if not np.allclose(
850
+ np.round(target.sampling_rate, 2), np.round(template.sampling_rate, 2)
851
+ ):
866
852
  print(
867
853
  f"Resampling template to {target.sampling_rate}. "
868
854
  "Consider providing a template with the same sampling rate as the target."
@@ -877,9 +863,6 @@ def main():
877
863
  )
878
864
 
879
865
  initial_shape = target.shape
880
- is_cropped = crop_data(
881
- data=target, data_mask=target_mask, cutoff=args.cutoff_target
882
- )
883
866
  print_block(
884
867
  name="Target",
885
868
  data={
@@ -888,13 +871,6 @@ def main():
888
871
  "Final Shape": target.shape,
889
872
  },
890
873
  )
891
- if is_cropped:
892
- args.target = generate_tempfile_name(suffix=".mrc")
893
- target.to_file(args.target)
894
-
895
- if target_mask:
896
- args.target_mask = generate_tempfile_name(suffix=".mrc")
897
- target_mask.to_file(args.target_mask)
898
874
 
899
875
  if target_mask:
900
876
  print_block(
@@ -907,8 +883,6 @@ def main():
907
883
  )
908
884
 
909
885
  initial_shape = template.shape
910
- _ = crop_data(data=template, data_mask=template_mask, cutoff=args.cutoff_template)
911
-
912
886
  translation = np.zeros(len(template.shape), dtype=np.float32)
913
887
  if not args.no_centering:
914
888
  template, translation = template.centered(0)
@@ -956,7 +930,7 @@ def main():
956
930
 
957
931
  if args.scramble_phases:
958
932
  template.data = scramble_phases(
959
- template.data, noise_proportion=1.0, normalize_power=True
933
+ template.data, noise_proportion=1.0, normalize_power=False
960
934
  )
961
935
 
962
936
  # Determine suitable backend for the selected operation
@@ -980,10 +954,7 @@ def main():
980
954
  available_backends.remove("jax")
981
955
  if args.use_gpu and "pytorch" in available_backends:
982
956
  available_backends = ("pytorch",)
983
- if args.interpolation_order == 3:
984
- raise NotImplementedError(
985
- "Pytorch does not support --interpolation_order 3, 1 is supported."
986
- )
957
+
987
958
  # dim_match = len(template.shape) == len(target.shape) <= 3
988
959
  # if dim_match and args.use_gpu and "jax" in available_backends:
989
960
  # args.interpolation_order = 1
@@ -1008,6 +979,12 @@ def main():
1008
979
  )
1009
980
  break
1010
981
 
982
+ if pref == "pytorch" and args.interpolation_order == 3:
983
+ warnings.warn(
984
+ "Pytorch does not support --interpolation_order 3, setting it to 1."
985
+ )
986
+ args.interpolation_order = 1
987
+
1011
988
  available_memory = be.get_available_memory() * be.device_count()
1012
989
  if args.memory is None:
1013
990
  args.memory = int(args.memory_scaling * available_memory)
@@ -1025,41 +1002,13 @@ def main():
1025
1002
  rotations=parse_rotation_logic(args=args, ndim=template.data.ndim),
1026
1003
  )
1027
1004
 
1005
+ matching_setup, matching_score = MATCHING_EXHAUSTIVE_REGISTER[args.score]
1028
1006
  matching_data.template_filter, matching_data.target_filter = setup_filter(
1029
1007
  args, template, target
1030
1008
  )
1031
1009
 
1032
- template_box = matching_data._output_template_shape
1033
- if not args.pad_fourier:
1034
- template_box = tuple(0 for _ in range(len(template_box)))
1035
-
1036
- target_padding = tuple(0 for _ in range(len(template_box)))
1037
- if args.pad_target_edges:
1038
- target_padding = matching_data._output_template_shape
1039
-
1040
- splits, schedule = compute_parallelization_schedule(
1041
- shape1=target.shape,
1042
- shape2=tuple(int(x) for x in template_box),
1043
- shape1_padding=tuple(int(x) for x in target_padding),
1044
- max_cores=args.cores,
1045
- max_ram=args.memory,
1046
- split_only_outer=args.use_gpu,
1047
- matching_method=args.score,
1048
- analyzer_method=callback_class.__name__,
1049
- backend=be._backend_name,
1050
- float_nbytes=be.datatype_bytes(be._float_dtype),
1051
- complex_nbytes=be.datatype_bytes(be._complex_dtype),
1052
- integer_nbytes=be.datatype_bytes(be._int_dtype),
1053
- )
1054
-
1055
- if splits is None:
1056
- print(
1057
- "Found no suitable parallelization schedule. Consider increasing"
1058
- " available RAM or decreasing number of cores."
1059
- )
1060
- exit(-1)
1010
+ splits, schedule = compute_schedule(args, target, matching_data, callback_class)
1061
1011
 
1062
- matching_setup, matching_score = MATCHING_EXHAUSTIVE_REGISTER[args.score]
1063
1012
  n_splits = np.prod(list(splits.values()))
1064
1013
  target_split = ", ".join(
1065
1014
  [":".join([str(x) for x in axis]) for axis in splits.items()]
@@ -1071,8 +1020,8 @@ def main():
1071
1020
  "Center Template": not args.no_centering,
1072
1021
  "Scramble Template": args.scramble_phases,
1073
1022
  "Invert Contrast": args.invert_target_contrast,
1074
- "Extend Fourier Grid": not args.no_fourier_padding,
1075
- "Extend Target Edges": not args.no_edge_padding,
1023
+ "Extend Fourier Grid": args.pad_fourier,
1024
+ "Extend Target Edges": args.pad_edges,
1076
1025
  "Interpolation Order": args.interpolation_order,
1077
1026
  "Setup Function": f"{get_func_fullname(matching_setup)}",
1078
1027
  "Scoring Function": f"{get_func_fullname(matching_score)}",
@@ -1108,6 +1057,7 @@ def main():
1108
1057
  "Tilt Angles": args.tilt_angles,
1109
1058
  "Tilt Weighting": args.tilt_weighting,
1110
1059
  "Reconstruction Filter": args.reconstruction_filter,
1060
+ "Extend Filter Grid": args.pad_filter,
1111
1061
  }
1112
1062
  if args.ctf_file is not None or args.defocus is not None:
1113
1063
  filter_args["CTF File"] = args.ctf_file
@@ -1130,8 +1080,7 @@ def main():
1130
1080
  analyzer_args = {
1131
1081
  "score_threshold": args.score_threshold,
1132
1082
  "number_of_peaks": args.number_of_peaks,
1133
- "min_distance": max(template.shape) // 2,
1134
- "min_boundary_distance": max(template.shape) // 2,
1083
+ "min_distance": max(template.shape) // 3,
1135
1084
  "use_memmap": args.use_memmap,
1136
1085
  }
1137
1086
  print_block(
@@ -1156,9 +1105,9 @@ def main():
1156
1105
  callback_class=callback_class,
1157
1106
  callback_class_args=analyzer_args,
1158
1107
  target_splits=splits,
1159
- pad_target_edges=args.pad_target_edges,
1108
+ pad_target_edges=args.pad_edges,
1160
1109
  pad_fourier=args.pad_fourier,
1161
- pad_template_filter=not args.no_filter_padding,
1110
+ pad_template_filter=args.pad_filter,
1162
1111
  interpolation_order=args.interpolation_order,
1163
1112
  )
1164
1113
 
@@ -509,19 +509,7 @@ def main():
509
509
 
510
510
  target = Density.from_file(cli_args.target)
511
511
  if args.invert_target_contrast:
512
- if args.output_format == "relion":
513
- target.data = target.data * -1
514
- target.data = np.divide(
515
- np.subtract(target.data, target.data.mean()), target.data.std()
516
- )
517
- else:
518
- target.data = (
519
- -np.divide(
520
- np.subtract(target.data, target.data.min()),
521
- np.subtract(target.data.max(), target.data.min()),
522
- )
523
- + 1
524
- )
512
+ target.data = target.data * -1
525
513
 
526
514
  if args.output_format in ("extraction", "relion"):
527
515
  if not np.all(np.divide(target.shape, template.shape) > 2):
@@ -546,10 +534,14 @@ def main():
546
534
 
547
535
  working_directory = getcwd()
548
536
  if args.output_format == "relion":
537
+ name = [
538
+ join(working_directory, f"{args.output_prefix}_{index}.mrc")
539
+ for index in range(len(cand_slices))
540
+ ]
549
541
  orientations.to_file(
550
542
  filename=f"{args.output_prefix}.star",
551
543
  file_format="relion",
552
- name_prefix=join(working_directory, args.output_prefix),
544
+ name=name,
553
545
  ctf_image=args.wedge_mask,
554
546
  sampling_rate=target.sampling_rate.max(),
555
547
  subtomogram_size=extraction_shape[0],
@@ -606,24 +598,22 @@ def main():
606
598
  if args.output_format == "average":
607
599
  orientations, cand_slices, obs_slices = orientations.get_extraction_slices(
608
600
  target_shape=target.shape,
609
- extraction_shape=np.multiply(template.shape, 2),
601
+ extraction_shape=template.shape,
610
602
  drop_out_of_box=True,
611
603
  return_orientations=True,
612
604
  )
613
605
  out = np.zeros_like(template.data)
614
- # out = np.zeros(np.multiply(template.shape, 2).astype(int))
615
606
  for index in range(len(cand_slices)):
616
- from scipy.spatial.transform import Rotation
617
-
618
- rotation = Rotation.from_euler(
619
- angles=orientations.rotations[index], seq="zyx", degrees=True
620
- )
621
- rotation_matrix = rotation.inv().as_matrix()
622
-
623
607
  subset = Density(target.data[obs_slices[index]])
624
- subset = subset.rigid_transform(rotation_matrix=rotation_matrix, order=1)
608
+ rotation_matrix = euler_to_rotationmatrix(orientations.rotations[index])
625
609
 
610
+ subset = subset.rigid_transform(
611
+ rotation_matrix=np.linalg.inv(rotation_matrix),
612
+ order=1,
613
+ use_geometric_center=True,
614
+ )
626
615
  np.add(out, subset.data, out=out)
616
+
627
617
  out /= len(cand_slices)
628
618
  ret = Density(out, sampling_rate=template.sampling_rate, origin=0)
629
619
  ret.pad(template.shape, center=True)
@@ -637,17 +627,18 @@ def main():
637
627
  target_shape=target.shape,
638
628
  )
639
629
 
630
+ # Template is larger than target
640
631
  for index, (translation, angles, *_) in enumerate(orientations):
641
632
  rotation_matrix = euler_to_rotationmatrix(angles)
642
633
  if template_is_density:
643
- translation = np.subtract(translation, center)
644
634
  transformed_template = template.rigid_transform(
645
- rotation_matrix=rotation_matrix
646
- )
647
- transformed_template.origin = np.add(
648
- target_origin, np.multiply(translation, sampling_rate)
635
+ rotation_matrix=rotation_matrix, use_geometric_center=True
649
636
  )
650
637
 
638
+ # Just adapting the coordinate system not the in-box position
639
+ shift = np.multiply(np.subtract(translation, center), sampling_rate)
640
+ transformed_template.origin = np.add(target_origin, shift)
641
+
651
642
  else:
652
643
  template = Structure.from_file(cli_args.template)
653
644
  new_center_of_mass = np.add(
@@ -0,0 +1,132 @@
1
+ #!python
2
+ """ Preprocessing routines for template matching.
3
+
4
+ Copyright (c) 2023 European Molecular Biology Laboratory
5
+
6
+ Author: Valentin Maurer <valentin.maurer@embl-hamburg.de>
7
+ """
8
+ import warnings
9
+ import argparse
10
+ import numpy as np
11
+
12
+ from tme import Density, Structure
13
+ from tme.backends import backend as be
14
+ from tme.preprocessing.frequency_filters import BandPassFilter
15
+
16
+
17
+ def parse_args():
18
+ parser = argparse.ArgumentParser(
19
+ description="Perform template matching preprocessing.",
20
+ formatter_class=argparse.ArgumentDefaultsHelpFormatter,
21
+ )
22
+
23
+ io_group = parser.add_argument_group("Input / Output")
24
+ io_group.add_argument(
25
+ "-m",
26
+ "--data",
27
+ dest="data",
28
+ type=str,
29
+ required=True,
30
+ help="Path to a file in PDB/MMCIF, CCP4/MRC, EM, H5 or a format supported by "
31
+ "tme.density.Density.from_file "
32
+ "https://kosinskilab.github.io/pyTME/reference/api/tme.density.Density.from_file.html",
33
+ )
34
+ io_group.add_argument(
35
+ "-o",
36
+ "--output",
37
+ dest="output",
38
+ type=str,
39
+ required=True,
40
+ help="Path the output should be written to.",
41
+ )
42
+
43
+ box_group = parser.add_argument_group("Box")
44
+ box_group.add_argument(
45
+ "--box_size",
46
+ dest="box_size",
47
+ type=int,
48
+ required=True,
49
+ help="Box size of the output",
50
+ )
51
+ box_group.add_argument(
52
+ "--sampling_rate",
53
+ dest="sampling_rate",
54
+ type=float,
55
+ required=True,
56
+ help="Sampling rate of the output file.",
57
+ )
58
+
59
+ modulation_group = parser.add_argument_group("Modulation")
60
+ modulation_group.add_argument(
61
+ "--invert_contrast",
62
+ dest="invert_contrast",
63
+ action="store_true",
64
+ required=False,
65
+ help="Inverts the template contrast.",
66
+ )
67
+ modulation_group.add_argument(
68
+ "--lowpass",
69
+ dest="lowpass",
70
+ type=float,
71
+ required=False,
72
+ default=None,
73
+ help="Lowpass filter the template to the given resolution. Nyquist by default. "
74
+ "A value of 0 disables the filter.",
75
+ )
76
+ modulation_group.add_argument(
77
+ "--no_centering",
78
+ dest="no_centering",
79
+ action="store_true",
80
+ help="Assumes the template is already centered and omits centering.",
81
+ )
82
+ args = parser.parse_args()
83
+ return args
84
+
85
+
86
+ def main():
87
+ args = parse_args()
88
+
89
+ try:
90
+ data = Structure.from_file(args.data)
91
+ data = Density.from_structure(data, sampling_rate=args.sampling_rate)
92
+ except NotImplementedError:
93
+ data = Density.from_file(args.data)
94
+
95
+ if not args.no_centering:
96
+ data, _ = data.centered(0)
97
+
98
+ recommended_box = be.compute_convolution_shapes([args.box_size], [1])[1][0]
99
+ if recommended_box != args.box_size:
100
+ warnings.warn(
101
+ f"Consider using --box_size {recommended_box} instead of {args.box_size}."
102
+ )
103
+
104
+ data.pad(
105
+ np.multiply(args.box_size, np.divide(args.sampling_rate, data.sampling_rate)),
106
+ center=True,
107
+ )
108
+
109
+ bpf_mask = 1
110
+ lowpass = 2 * args.sampling_rate if args.lowpass is None else args.lowpass
111
+ if args.lowpass != 0:
112
+ bpf_mask = BandPassFilter(
113
+ lowpass=lowpass,
114
+ highpass=None,
115
+ use_gaussian=True,
116
+ return_real_fourier=True,
117
+ shape_is_real_fourier=False,
118
+ )(shape=data.shape)["data"]
119
+
120
+ data_ft = np.fft.rfftn(data.data, s=data.shape)
121
+ data_ft = np.multiply(data_ft, bpf_mask, out=data_ft)
122
+ data.data = np.fft.irfftn(data_ft, s=data.shape).real
123
+
124
+ data = data.resample(args.sampling_rate, method="spline", order=3)
125
+
126
+ if args.invert_contrast:
127
+ data.data = data.data * -1
128
+
129
+ data.to_file(args.output)
130
+
131
+ if __name__ == "__main__":
132
+ main()
@@ -132,14 +132,6 @@ def local_gaussian_filter(
132
132
  )
133
133
 
134
134
 
135
- def ntree(
136
- template: NDArray,
137
- sigma_range: Tuple[float, float],
138
- **kwargs: dict,
139
- ) -> NDArray:
140
- return preprocessor.ntree_filter(template=template, sigma_range=sigma_range)
141
-
142
-
143
135
  def mean(
144
136
  template: NDArray,
145
137
  width: int,
@@ -197,6 +189,10 @@ def compute_power_spectrum(template: NDArray) -> NDArray:
197
189
  return np.fft.fftshift(np.log(np.abs(np.fft.fftn(template))))
198
190
 
199
191
 
192
+ def invert_contrast(template: NDArray) -> NDArray:
193
+ return template * -1
194
+
195
+
200
196
  def widgets_from_function(function: Callable, exclude_params: List = ["self"]):
201
197
  """
202
198
  Creates list of magicui widgets by inspecting function typing ann
@@ -252,13 +248,13 @@ WRAPPED_FUNCTIONS = {
252
248
  "gaussian_filter": gaussian_filter,
253
249
  "bandpass_filter": bandpass_filter,
254
250
  "edge_gaussian_filter": edge_gaussian_filter,
255
- "ntree_filter": ntree,
256
251
  "local_gaussian_filter": local_gaussian_filter,
257
252
  "difference_of_gaussian_filter": difference_of_gaussian_filter,
258
253
  "mean_filter": mean,
259
254
  "wedge_filter": wedge,
260
255
  "power_spectrum": compute_power_spectrum,
261
256
  "ctf": ctf_filter,
257
+ "invert_contrast": invert_contrast,
262
258
  }
263
259
 
264
260
  EXCLUDED_FUNCTIONS = [
@@ -634,6 +630,7 @@ class MaskWidget(widgets.Container):
634
630
 
635
631
  data = active_layer.data.copy()
636
632
  cutoff = np.quantile(data, self.percentile_range_edit.value / 100)
633
+ cutoff = max(cutoff, np.finfo(np.float32).resolution)
637
634
  data[data < cutoff] = 0
638
635
 
639
636
  center_of_mass = Density.center_of_mass(np.abs(data), 0)