pytme 0.2.0__cp311-cp311-macosx_14_0_arm64.whl → 0.2.1__cp311-cp311-macosx_14_0_arm64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (40)
  1. {pytme-0.2.0.data → pytme-0.2.1.data}/scripts/match_template.py +183 -69
  2. {pytme-0.2.0.data → pytme-0.2.1.data}/scripts/postprocess.py +107 -49
  3. {pytme-0.2.0.data → pytme-0.2.1.data}/scripts/preprocessor_gui.py +4 -1
  4. {pytme-0.2.0.dist-info → pytme-0.2.1.dist-info}/METADATA +1 -1
  5. pytme-0.2.1.dist-info/RECORD +73 -0
  6. scripts/extract_candidates.py +117 -85
  7. scripts/match_template.py +183 -69
  8. scripts/match_template_filters.py +193 -71
  9. scripts/postprocess.py +107 -49
  10. scripts/preprocessor_gui.py +4 -1
  11. scripts/refine_matches.py +364 -160
  12. tme/__version__.py +1 -1
  13. tme/analyzer.py +259 -117
  14. tme/backends/__init__.py +1 -0
  15. tme/backends/cupy_backend.py +20 -13
  16. tme/backends/jax_backend.py +218 -0
  17. tme/backends/matching_backend.py +25 -10
  18. tme/backends/mlx_backend.py +13 -9
  19. tme/backends/npfftw_backend.py +20 -8
  20. tme/backends/pytorch_backend.py +20 -9
  21. tme/density.py +79 -60
  22. tme/extensions.cpython-311-darwin.so +0 -0
  23. tme/matching_data.py +85 -61
  24. tme/matching_exhaustive.py +222 -129
  25. tme/matching_optimization.py +117 -76
  26. tme/orientations.py +175 -55
  27. tme/preprocessing/_utils.py +17 -5
  28. tme/preprocessing/composable_filter.py +2 -1
  29. tme/preprocessing/compose.py +1 -2
  30. tme/preprocessing/frequency_filters.py +97 -41
  31. tme/preprocessing/tilt_series.py +137 -87
  32. tme/preprocessor.py +3 -0
  33. tme/structure.py +4 -1
  34. pytme-0.2.0.dist-info/RECORD +0 -72
  35. {pytme-0.2.0.data → pytme-0.2.1.data}/scripts/estimate_ram_usage.py +0 -0
  36. {pytme-0.2.0.data → pytme-0.2.1.data}/scripts/preprocess.py +0 -0
  37. {pytme-0.2.0.dist-info → pytme-0.2.1.dist-info}/LICENSE +0 -0
  38. {pytme-0.2.0.dist-info → pytme-0.2.1.dist-info}/WHEEL +0 -0
  39. {pytme-0.2.0.dist-info → pytme-0.2.1.dist-info}/entry_points.txt +0 -0
  40. {pytme-0.2.0.dist-info → pytme-0.2.1.dist-info}/top_level.txt +0 -0
@@ -178,6 +178,7 @@ def parse_rotation_logic(args, ndim):
 
 
 # TODO: Think about whether wedge mask should also be added to target
+# For now leave it at the cost of incorrect upper bound on the scores
 def setup_filter(args, template: Density, target: Density) -> Tuple[Compose, Compose]:
     from tme.preprocessing import LinearWhiteningFilter, BandPassFilter
     from tme.preprocessing.tilt_series import (
@@ -216,7 +217,7 @@ def setup_filter(args, template: Density, target: Density) -> Tuple[Compose, Compose]:
         if args.tilt_weighting not in ("angle", None):
             raise ValueError(
                 "Tilt weighting schemes other than 'angle' or 'None' require "
-                "a specification of electron doses."
+                "a specification of electron doses via --tilt_angles."
             )
 
         wedge = Wedge(
@@ -239,31 +240,58 @@ def setup_filter(args, template: Density, target: Density) -> Tuple[Compose, Compose]:
         wedge.sampling_rate = template.sampling_rate
         template_filter.append(wedge)
         if not isinstance(wedge, WedgeReconstructed):
-            template_filter.append(ReconstructFromTilt(
-                reconstruction_filter = args.reconstruction_filter
-            ))
+            template_filter.append(
+                ReconstructFromTilt(
+                    reconstruction_filter=args.reconstruction_filter,
+                    interpolation_order=args.reconstruction_interpolation_order,
+                )
+            )
 
-    if args.ctf_file is not None:
+    if args.ctf_file is not None or args.defocus is not None:
         from tme.preprocessing.tilt_series import CTF
 
-        ctf = CTF.from_file(args.ctf_file)
-        n_tilts_ctfs, n_tils_angles = len(ctf.defocus_x), len(wedge.angles)
-        if n_tilts_ctfs != n_tils_angles:
-            raise ValueError(
-                f"CTF file contains {n_tilts_ctfs} micrographs, but match_template "
-                f"recieved {n_tils_angles} tilt angles. Expected one angle "
-                "per micrograph."
+        needs_reconstruction = True
+        if args.ctf_file is not None:
+            ctf = CTF.from_file(args.ctf_file)
+            n_tilts_ctfs, n_tils_angles = len(ctf.defocus_x), len(wedge.angles)
+            if n_tilts_ctfs != n_tils_angles:
+                raise ValueError(
+                    f"CTF file contains {n_tilts_ctfs} micrographs, but match_template "
+                    f"recieved {n_tils_angles} tilt angles. Expected one angle "
+                    "per micrograph."
+                )
+            ctf.angles = wedge.angles
+            ctf.opening_axis, ctf.tilt_axis = args.wedge_axes
+        else:
+            needs_reconstruction = False
+            ctf = CTF(
+                defocus_x=[args.defocus],
+                phase_shift=[args.phase_shift],
+                defocus_y=None,
+                angles=[0],
+                shape=None,
+                return_real_fourier=True,
             )
-        ctf.angles = wedge.angles
-        ctf.opening_axis, ctf.tilt_axis = args.wedge_axes
-
-        if isinstance(template_filter[-1], ReconstructFromTilt):
+        ctf.sampling_rate = template.sampling_rate
+        ctf.flip_phase = not args.no_flip_phase
+        ctf.amplitude_contrast = args.amplitude_contrast
+        ctf.spherical_aberration = args.spherical_aberration
+        ctf.acceleration_voltage = args.acceleration_voltage * 1e3
+        ctf.correct_defocus_gradient = args.correct_defocus_gradient
+
+        if not needs_reconstruction:
+            template_filter.append(ctf)
+        elif isinstance(template_filter[-1], ReconstructFromTilt):
             template_filter.insert(-1, ctf)
         else:
             template_filter.insert(0, ctf)
-            template_filter.insert(1, ReconstructFromTilt(
-                reconstruction_filter = args.reconstruction_filter
-            ))
+            template_filter.insert(
+                1,
+                ReconstructFromTilt(
+                    reconstruction_filter=args.reconstruction_filter,
+                    interpolation_order=args.reconstruction_interpolation_order,
+                ),
+            )
 
     if args.lowpass or args.highpass is not None:
         lowpass, highpass = args.lowpass, args.highpass
@@ -292,6 +320,14 @@ def setup_filter(args, template: Density, target: Density) -> Tuple[Compose, Compose]:
         template_filter.append(whitening_filter)
         target_filter.append(whitening_filter)
 
+    needs_reconstruction = any(
+        [isinstance(t, ReconstructFromTilt) for t in template_filter]
+    )
+    if needs_reconstruction and args.reconstruction_filter is None:
+        warnings.warn(
+            "Consider using a --reconstruction_filter such as 'ramp' to avoid artifacts."
+        )
+
     template_filter = Compose(template_filter) if len(template_filter) else None
     target_filter = Compose(target_filter) if len(target_filter) else None
 
@@ -509,7 +545,7 @@ def parse_args():
         dest="no_pass_smooth",
         action="store_false",
         default=True,
-        help="Whether a hard edge filter should be used for --lowpass and --highpass."
+        help="Whether a hard edge filter should be used for --lowpass and --highpass.",
     )
     filter_group.add_argument(
         "--pass_format",
@@ -518,7 +554,7 @@ def parse_args():
         required=False,
         choices=["sampling_rate", "voxel", "frequency"],
         help="How values passed to --lowpass and --highpass should be interpreted. "
-        "By default, they are assumed to be in units of sampling rate, e.g. Ångstrom."
+        "By default, they are assumed to be in units of sampling rate, e.g. Ångstrom.",
     )
     filter_group.add_argument(
         "--whiten_spectrum",
@@ -560,23 +596,90 @@ def parse_args():
         "grigorieff (exposure filter as defined in Grant and Grigorieff 2015)."
         "relion and grigorieff require electron doses in --tilt_angles weights column.",
     )
-    # filter_group.add_argument(
-    #     "--ctf_file",
-    #     dest="ctf_file",
-    #     type=str,
-    #     required=False,
-    #     default=None,
-    #     help="Path to a file with CTF parameters from CTFFIND4.",
-    # )
     filter_group.add_argument(
         "--reconstruction_filter",
         dest="reconstruction_filter",
         type=str,
         required=False,
-        choices = ["ram-lak", "ramp", "shepp-logan", "cosine", "hamming"],
+        choices=["ram-lak", "ramp", "ramp-cont", "shepp-logan", "cosine", "hamming"],
         default=None,
         help="Filter applied when reconstructing (N+1)-D from N-D filters.",
     )
+    filter_group.add_argument(
+        "--reconstruction_interpolation_order",
+        dest="reconstruction_interpolation_order",
+        type=int,
+        default=1,
+        required=False,
+        help="Analogous to --interpolation_order but for reconstruction.",
+    )
+
+    ctf_group = parser.add_argument_group("Contrast Transfer Function")
+    ctf_group.add_argument(
+        "--ctf_file",
+        dest="ctf_file",
+        type=str,
+        required=False,
+        default=None,
+        help="Path to a file with CTF parameters from CTFFIND4. Each line will be "
+        "interpreted as tilt obtained at the angle specified in --tilt_angles. ",
+    )
+    ctf_group.add_argument(
+        "--defocus",
+        dest="defocus",
+        type=float,
+        required=False,
+        default=None,
+        help="Defocus in units of sampling rate (typically Ångstrom). "
+        "Superseded by --ctf_file.",
+    )
+    ctf_group.add_argument(
+        "--phase_shift",
+        dest="phase_shift",
+        type=float,
+        required=False,
+        default=0,
+        help="Phase shift in degrees. Superseded by --ctf_file.",
+    )
+    ctf_group.add_argument(
+        "--acceleration_voltage",
+        dest="acceleration_voltage",
+        type=float,
+        required=False,
+        default=300,
+        help="Acceleration voltage in kV, defaults to 300.",
+    )
+    ctf_group.add_argument(
+        "--spherical_aberration",
+        dest="spherical_aberration",
+        type=float,
+        required=False,
+        default=2.7e7,
+        help="Spherical aberration in units of sampling rate (typically Ångstrom).",
+    )
+    ctf_group.add_argument(
+        "--amplitude_contrast",
+        dest="amplitude_contrast",
+        type=float,
+        required=False,
+        default=0.07,
+        help="Amplitude contrast, defaults to 0.07.",
+    )
+    ctf_group.add_argument(
+        "--no_flip_phase",
+        dest="no_flip_phase",
+        action="store_false",
+        required=False,
+        help="Whether the phase of the computed CTF should not be flipped.",
+    )
+    ctf_group.add_argument(
+        "--correct_defocus_gradient",
+        dest="correct_defocus_gradient",
+        action="store_true",
+        required=False,
+        help="[Experimental] Whether to compute a more accurate 3D CTF incorporating "
+        "defocus gradients.",
+    )
 
     performance_group = parser.add_argument_group("Performance")
     performance_group.add_argument(
@@ -654,12 +757,11 @@ def parse_args():
     )
 
     args = parser.parse_args()
+    args.version = __version__
 
     if args.interpolation_order < 0:
         args.interpolation_order = None
 
-    args.ctf_file = None
-
     if args.temp_directory is None:
         default = abspath(".")
         if os.environ.get("TMPDIR", None) is not None:
@@ -719,17 +821,20 @@ def main():
     try:
         template = Density.from_file(args.template)
     except Exception:
+        drop = target.metadata.get("batch_dimension", ())
+        keep = [i not in drop for i in range(target.data.ndim)]
         template = Density.from_structure(
             filename_or_structure=args.template,
-            sampling_rate=target.sampling_rate,
+            sampling_rate=target.sampling_rate[keep],
         )
 
-    if not np.allclose(target.sampling_rate, template.sampling_rate):
-        print(
-            f"Resampling template to {target.sampling_rate}. "
-            "Consider providing a template with the same sampling rate as the target."
-        )
-        template = template.resample(target.sampling_rate, order=3)
+    if target.sampling_rate.size == template.sampling_rate.size:
+        if not np.allclose(target.sampling_rate, template.sampling_rate):
+            print(
+                f"Resampling template to {target.sampling_rate}. "
+                "Consider providing a template with the same sampling rate as the target."
+            )
+            template = template.resample(target.sampling_rate, order=3)
 
     template_mask = load_and_validate_mask(
         mask_target=template, mask_path=args.template_mask
@@ -862,31 +967,52 @@ def main():
     if args.memory is None:
        args.memory = int(args.memory_scaling * available_memory)
 
-    target_padding = np.zeros_like(template.shape)
-    if args.pad_target_edges:
-        target_padding = template.shape
+    callback_class = MaxScoreOverRotations
+    if args.peak_calling:
+        callback_class = PeakCallerMaximumFilter
+
+    matching_data = MatchingData(
+        target=target,
+        template=template.data,
+        target_mask=target_mask,
+        template_mask=template_mask,
+        invert_target=args.invert_target_contrast,
+        rotations=parse_rotation_logic(args=args, ndim=template.data.ndim),
+    )
 
-    template_box = template.shape
+    template_filter, target_filter = setup_filter(args, template, target)
+    matching_data.template_filter = template_filter
+    matching_data.target_filter = target_filter
+
+    target_dims = target.metadata.get("batch_dimension", None)
+    matching_data._set_batch_dimension(target_dims=target_dims, template_dims=None)
+    args.score = "FLC2" if target_dims is not None else args.score
+    args.target_batch, args.template_batch = target_dims, None
+
+    template_box = matching_data._output_template_shape
     if not args.pad_fourier:
         template_box = np.ones(len(template_box), dtype=int)
 
-    callback_class = MaxScoreOverRotations
-    if args.peak_calling:
-        callback_class = PeakCallerMaximumFilter
+    target_padding = np.zeros(
+        (backend.size(matching_data._output_template_shape)), dtype=int
+    )
+    if args.pad_target_edges:
+        target_padding = matching_data._output_template_shape
 
     splits, schedule = compute_parallelization_schedule(
         shape1=target.shape,
-        shape2=template_box,
-        shape1_padding=target_padding,
+        shape2=tuple(int(x) for x in template_box),
+        shape1_padding=tuple(int(x) for x in target_padding),
         max_cores=args.cores,
         max_ram=args.memory,
         split_only_outer=args.use_gpu,
         matching_method=args.score,
         analyzer_method=callback_class.__name__,
         backend=backend._backend_name,
-        float_nbytes=backend.datatype_bytes(backend._default_dtype),
+        float_nbytes=backend.datatype_bytes(backend._float_dtype),
         complex_nbytes=backend.datatype_bytes(backend._complex_dtype),
-        integer_nbytes=backend.datatype_bytes(backend._default_dtype_int),
+        integer_nbytes=backend.datatype_bytes(backend._int_dtype),
+        split_axes=target_dims,
     )
 
     if splits is None:
@@ -897,20 +1023,6 @@ def main():
         exit(-1)
 
     matching_setup, matching_score = MATCHING_EXHAUSTIVE_REGISTER[args.score]
-    matching_data = MatchingData(target=target, template=template.data)
-    matching_data.rotations = parse_rotation_logic(args=args, ndim=target.data.ndim)
-
-    template_filter, target_filter = setup_filter(args, template, target)
-    matching_data.template_filter = template_filter
-    matching_data.target_filter = target_filter
-
-    matching_data.template_filter = template_filter
-    matching_data._invert_target = args.invert_target_contrast
-    if target_mask is not None:
-        matching_data.target_mask = target_mask
-    if template_mask is not None:
-        matching_data.template_mask = template_mask.data
-
     n_splits = np.prod(list(splits.values()))
     target_split = ", ".join(
         [":".join([str(x) for x in axis]) for axis in splits.items()]
@@ -944,13 +1056,23 @@ def main():
         "Lowpass": args.lowpass,
         "Highpass": args.highpass,
         "Smooth Pass": args.no_pass_smooth,
-        "Pass Format" : args.pass_format,
+        "Pass Format": args.pass_format,
         "Spectral Whitening": args.whiten_spectrum,
         "Wedge Axes": args.wedge_axes,
         "Tilt Angles": args.tilt_angles,
         "Tilt Weighting": args.tilt_weighting,
-        "CTF": args.ctf_file,
+        "Reconstruction Filter": args.reconstruction_filter,
     }
+    if args.ctf_file is not None or args.defocus is not None:
+        filter_args["CTF File"] = args.ctf_file
+        filter_args["Defocus"] = args.defocus
+        filter_args["Phase Shift"] = args.phase_shift
+        filter_args["No Flip Phase"] = args.no_flip_phase
+        filter_args["Acceleration Voltage"] = args.acceleration_voltage
+        filter_args["Spherical Aberration"] = args.spherical_aberration
+        filter_args["Amplitude Contrast"] = args.amplitude_contrast
+        filter_args["Correct Defocus"] = args.correct_defocus_gradient
+
     filter_args = {k: v for k, v in filter_args.items() if v is not None}
     if len(filter_args):
         print_block(
@@ -999,16 +1121,16 @@ def main():
             candidates[0] *= target_mask.data
         with warnings.catch_warnings():
             warnings.simplefilter("ignore", category=UserWarning)
+            nbytes = backend.datatype_bytes(backend._float_dtype)
+            dtype = np.float32 if nbytes == 4 else np.float16
+            rot_dim = matching_data.rotations.shape[1]
             candidates[3] = {
                 x: euler_from_rotationmatrix(
-                    np.frombuffer(i, dtype=matching_data.rotations.dtype).reshape(
-                        candidates[0].ndim, candidates[0].ndim
-                    )
+                    np.frombuffer(i, dtype=dtype).reshape(rot_dim, rot_dim)
                 )
                 for i, x in candidates[3].items()
             }
-
-    candidates.append((target.origin, template.origin, target.sampling_rate, args))
+    candidates.append((target.origin, template.origin, template.sampling_rate, args))
     write_pickle(data=candidates, filename=args.output)
 
     runtime = time() - start
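
For orientation, the --defocus code path added to setup_filter in this release amounts to building a single-image CTF filter from the new CLI values and appending it to the template filter stack. The following is a minimal sketch using only the names visible in the diff above; the defocus and sampling-rate values are illustrative placeholders, not values taken from the package.

    from tme.preprocessing.tilt_series import CTF

    # Mirrors the needs_reconstruction = False branch of setup_filter in 0.2.1.
    ctf = CTF(
        defocus_x=[30000.0],   # --defocus, in units of sampling rate (placeholder value)
        phase_shift=[0.0],     # --phase_shift, degrees (default 0)
        defocus_y=None,
        angles=[0],
        shape=None,
        return_real_fourier=True,
    )
    ctf.sampling_rate = 1.0                # the script assigns template.sampling_rate here
    ctf.flip_phase = True                  # not args.no_flip_phase
    ctf.amplitude_contrast = 0.07          # --amplitude_contrast default
    ctf.spherical_aberration = 2.7e7       # --spherical_aberration default
    ctf.acceleration_voltage = 300 * 1e3   # --acceleration_voltage is given in kV and scaled by 1e3
    ctf.correct_defocus_gradient = False   # --correct_defocus_gradient default
    # With no tilt reconstruction required, the script then runs: template_filter.append(ctf)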