pytme-0.2.0-cp311-cp311-macosx_14_0_arm64.whl → pytme-0.2.1-cp311-cp311-macosx_14_0_arm64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (40)
  1. {pytme-0.2.0.data → pytme-0.2.1.data}/scripts/match_template.py +183 -69
  2. {pytme-0.2.0.data → pytme-0.2.1.data}/scripts/postprocess.py +107 -49
  3. {pytme-0.2.0.data → pytme-0.2.1.data}/scripts/preprocessor_gui.py +4 -1
  4. {pytme-0.2.0.dist-info → pytme-0.2.1.dist-info}/METADATA +1 -1
  5. pytme-0.2.1.dist-info/RECORD +73 -0
  6. scripts/extract_candidates.py +117 -85
  7. scripts/match_template.py +183 -69
  8. scripts/match_template_filters.py +193 -71
  9. scripts/postprocess.py +107 -49
  10. scripts/preprocessor_gui.py +4 -1
  11. scripts/refine_matches.py +364 -160
  12. tme/__version__.py +1 -1
  13. tme/analyzer.py +259 -117
  14. tme/backends/__init__.py +1 -0
  15. tme/backends/cupy_backend.py +20 -13
  16. tme/backends/jax_backend.py +218 -0
  17. tme/backends/matching_backend.py +25 -10
  18. tme/backends/mlx_backend.py +13 -9
  19. tme/backends/npfftw_backend.py +20 -8
  20. tme/backends/pytorch_backend.py +20 -9
  21. tme/density.py +79 -60
  22. tme/extensions.cpython-311-darwin.so +0 -0
  23. tme/matching_data.py +85 -61
  24. tme/matching_exhaustive.py +222 -129
  25. tme/matching_optimization.py +117 -76
  26. tme/orientations.py +175 -55
  27. tme/preprocessing/_utils.py +17 -5
  28. tme/preprocessing/composable_filter.py +2 -1
  29. tme/preprocessing/compose.py +1 -2
  30. tme/preprocessing/frequency_filters.py +97 -41
  31. tme/preprocessing/tilt_series.py +137 -87
  32. tme/preprocessor.py +3 -0
  33. tme/structure.py +4 -1
  34. pytme-0.2.0.dist-info/RECORD +0 -72
  35. {pytme-0.2.0.data → pytme-0.2.1.data}/scripts/estimate_ram_usage.py +0 -0
  36. {pytme-0.2.0.data → pytme-0.2.1.data}/scripts/preprocess.py +0 -0
  37. {pytme-0.2.0.dist-info → pytme-0.2.1.dist-info}/LICENSE +0 -0
  38. {pytme-0.2.0.dist-info → pytme-0.2.1.dist-info}/WHEEL +0 -0
  39. {pytme-0.2.0.dist-info → pytme-0.2.1.dist-info}/entry_points.txt +0 -0
  40. {pytme-0.2.0.dist-info → pytme-0.2.1.dist-info}/top_level.txt +0 -0
--- a/scripts/match_template.py
+++ b/scripts/match_template.py
@@ -217,7 +217,7 @@ def setup_filter(args, template: Density, target: Density) -> Tuple[Compose, Com
     if args.tilt_weighting not in ("angle", None):
         raise ValueError(
             "Tilt weighting schemes other than 'angle' or 'None' require "
-            "a specification of electron doses."
+            "a specification of electron doses via --tilt_angles."
         )

     wedge = Wedge(
@@ -240,31 +240,58 @@ def setup_filter(args, template: Density, target: Density) -> Tuple[Compose, Com
     wedge.sampling_rate = template.sampling_rate
     template_filter.append(wedge)
     if not isinstance(wedge, WedgeReconstructed):
-        template_filter.append(ReconstructFromTilt(
-            reconstruction_filter = args.reconstruction_filter
-        ))
+        template_filter.append(
+            ReconstructFromTilt(
+                reconstruction_filter=args.reconstruction_filter,
+                interpolation_order=args.reconstruction_interpolation_order,
+            )
+        )

-    if args.ctf_file is not None:
+    if args.ctf_file is not None or args.defocus is not None:
         from tme.preprocessing.tilt_series import CTF

-        ctf = CTF.from_file(args.ctf_file)
-        n_tilts_ctfs, n_tils_angles = len(ctf.defocus_x), len(wedge.angles)
-        if n_tilts_ctfs != n_tils_angles:
-            raise ValueError(
-                f"CTF file contains {n_tilts_ctfs} micrographs, but match_template "
-                f"recieved {n_tils_angles} tilt angles. Expected one angle "
-                "per micrograph."
+        needs_reconstruction = True
+        if args.ctf_file is not None:
+            ctf = CTF.from_file(args.ctf_file)
+            n_tilts_ctfs, n_tils_angles = len(ctf.defocus_x), len(wedge.angles)
+            if n_tilts_ctfs != n_tils_angles:
+                raise ValueError(
+                    f"CTF file contains {n_tilts_ctfs} micrographs, but match_template "
+                    f"recieved {n_tils_angles} tilt angles. Expected one angle "
+                    "per micrograph."
+                )
+            ctf.angles = wedge.angles
+            ctf.opening_axis, ctf.tilt_axis = args.wedge_axes
+        else:
+            needs_reconstruction = False
+            ctf = CTF(
+                defocus_x=[args.defocus],
+                phase_shift=[args.phase_shift],
+                defocus_y=None,
+                angles=[0],
+                shape=None,
+                return_real_fourier=True,
             )
-        ctf.angles = wedge.angles
-        ctf.opening_axis, ctf.tilt_axis = args.wedge_axes
-
-        if isinstance(template_filter[-1], ReconstructFromTilt):
+        ctf.sampling_rate = template.sampling_rate
+        ctf.flip_phase = not args.no_flip_phase
+        ctf.amplitude_contrast = args.amplitude_contrast
+        ctf.spherical_aberration = args.spherical_aberration
+        ctf.acceleration_voltage = args.acceleration_voltage * 1e3
+        ctf.correct_defocus_gradient = args.correct_defocus_gradient
+
+        if not needs_reconstruction:
+            template_filter.append(ctf)
+        elif isinstance(template_filter[-1], ReconstructFromTilt):
             template_filter.insert(-1, ctf)
         else:
             template_filter.insert(0, ctf)
-            template_filter.insert(1, ReconstructFromTilt(
-                reconstruction_filter = args.reconstruction_filter
-            ))
+            template_filter.insert(
+                1,
+                ReconstructFromTilt(
+                    reconstruction_filter=args.reconstruction_filter,
+                    interpolation_order=args.reconstruction_interpolation_order,
+                ),
+            )

     if args.lowpass or args.highpass is not None:
         lowpass, highpass = args.lowpass, args.highpass
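The hunk above adds a second construction path for the CTF filter: with --ctf_file, per-tilt parameters come from CTFFIND4 output, are validated against the tilt angles, and pass through tilt reconstruction; with --defocus, a single CTF is built directly and appended without reconstruction. A minimal sketch of the single-defocus path, restricted to fields that appear in the hunk (the defocus value is a hypothetical placeholder; a loaded Density template is assumed):

    from tme.preprocessing.tilt_series import CTF

    ctf = CTF(
        defocus_x=[30000.0],     # hypothetical defocus, units of sampling rate
        phase_shift=[0],
        defocus_y=None,
        angles=[0],
        shape=None,
        return_real_fourier=True,
    )
    # Remaining attributes are set from CLI arguments, as in the hunk above:
    ctf.sampling_rate = template.sampling_rate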
@@ -293,6 +320,14 @@ def setup_filter(args, template: Density, target: Density) -> Tuple[Compose, Com
         template_filter.append(whitening_filter)
         target_filter.append(whitening_filter)

+    needs_reconstruction = any(
+        [isinstance(t, ReconstructFromTilt) for t in template_filter]
+    )
+    if needs_reconstruction and args.reconstruction_filter is None:
+        warnings.warn(
+            "Consider using a --reconstruction_filter such as 'ramp' to avoid artifacts."
+        )
+
     template_filter = Compose(template_filter) if len(template_filter) else None
     target_filter = Compose(target_filter) if len(target_filter) else None

@@ -510,7 +545,7 @@ def parse_args():
         dest="no_pass_smooth",
         action="store_false",
         default=True,
-        help="Whether a hard edge filter should be used for --lowpass and --highpass."
+        help="Whether a hard edge filter should be used for --lowpass and --highpass.",
     )
     filter_group.add_argument(
         "--pass_format",
@@ -519,7 +554,7 @@
         required=False,
         choices=["sampling_rate", "voxel", "frequency"],
         help="How values passed to --lowpass and --highpass should be interpreted. "
-        "By default, they are assumed to be in units of sampling rate, e.g. Ångstrom."
+        "By default, they are assumed to be in units of sampling rate, e.g. Ångstrom.",
     )
     filter_group.add_argument(
         "--whiten_spectrum",
@@ -561,23 +596,90 @@
         "grigorieff (exposure filter as defined in Grant and Grigorieff 2015)."
         "relion and grigorieff require electron doses in --tilt_angles weights column.",
     )
-    # filter_group.add_argument(
-    #     "--ctf_file",
-    #     dest="ctf_file",
-    #     type=str,
-    #     required=False,
-    #     default=None,
-    #     help="Path to a file with CTF parameters from CTFFIND4.",
-    # )
     filter_group.add_argument(
         "--reconstruction_filter",
         dest="reconstruction_filter",
         type=str,
         required=False,
-        choices = ["ram-lak", "ramp", "shepp-logan", "cosine", "hamming"],
+        choices=["ram-lak", "ramp", "ramp-cont", "shepp-logan", "cosine", "hamming"],
         default=None,
         help="Filter applied when reconstructing (N+1)-D from N-D filters.",
     )
+    filter_group.add_argument(
+        "--reconstruction_interpolation_order",
+        dest="reconstruction_interpolation_order",
+        type=int,
+        default=1,
+        required=False,
+        help="Analogous to --interpolation_order but for reconstruction.",
+    )
+
+    ctf_group = parser.add_argument_group("Contrast Transfer Function")
+    ctf_group.add_argument(
+        "--ctf_file",
+        dest="ctf_file",
+        type=str,
+        required=False,
+        default=None,
+        help="Path to a file with CTF parameters from CTFFIND4. Each line will be "
+        "interpreted as tilt obtained at the angle specified in --tilt_angles. ",
+    )
+    ctf_group.add_argument(
+        "--defocus",
+        dest="defocus",
+        type=float,
+        required=False,
+        default=None,
+        help="Defocus in units of sampling rate (typically Ångstrom). "
+        "Superseded by --ctf_file.",
+    )
+    ctf_group.add_argument(
+        "--phase_shift",
+        dest="phase_shift",
+        type=float,
+        required=False,
+        default=0,
+        help="Phase shift in degrees. Superseded by --ctf_file.",
+    )
+    ctf_group.add_argument(
+        "--acceleration_voltage",
+        dest="acceleration_voltage",
+        type=float,
+        required=False,
+        default=300,
+        help="Acceleration voltage in kV, defaults to 300.",
+    )
+    ctf_group.add_argument(
+        "--spherical_aberration",
+        dest="spherical_aberration",
+        type=float,
+        required=False,
+        default=2.7e7,
+        help="Spherical aberration in units of sampling rate (typically Ångstrom).",
+    )
+    ctf_group.add_argument(
+        "--amplitude_contrast",
+        dest="amplitude_contrast",
+        type=float,
+        required=False,
+        default=0.07,
+        help="Amplitude contrast, defaults to 0.07.",
+    )
+    ctf_group.add_argument(
+        "--no_flip_phase",
+        dest="no_flip_phase",
+        action="store_false",
+        required=False,
+        help="Whether the phase of the computed CTF should not be flipped.",
+    )
+    ctf_group.add_argument(
+        "--correct_defocus_gradient",
+        dest="correct_defocus_gradient",
+        action="store_true",
+        required=False,
+        help="[Experimental] Whether to compute a more accurate 3D CTF incorporating "
+        "defocus gradients.",
+    )

     performance_group = parser.add_argument_group("Performance")
     performance_group.add_argument(
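Taken together, a hypothetical invocation exercising the new Contrast Transfer Function group alongside the extended reconstruction options might look as follows ([...] stands for the unchanged target/template arguments, which are not part of this diff; numeric values mirror the documented defaults):

    match_template.py [...] \
        --defocus 30000 --phase_shift 0 \
        --acceleration_voltage 300 --amplitude_contrast 0.07 \
        --spherical_aberration 2.7e7 \
        --reconstruction_filter ramp --reconstruction_interpolation_order 1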
@@ -655,12 +757,11 @@ def parse_args():
     )

     args = parser.parse_args()
+    args.version = __version__

     if args.interpolation_order < 0:
         args.interpolation_order = None

-    args.ctf_file = None
-
     if args.temp_directory is None:
         default = abspath(".")
         if os.environ.get("TMPDIR", None) is not None:
@@ -725,12 +826,13 @@ def main():
         sampling_rate=target.sampling_rate,
     )

-    if not np.allclose(target.sampling_rate, template.sampling_rate):
-        print(
-            f"Resampling template to {target.sampling_rate}. "
-            "Consider providing a template with the same sampling rate as the target."
-        )
-        template = template.resample(target.sampling_rate, order=3)
+    if target.sampling_rate.size == template.sampling_rate.size:
+        if not np.allclose(target.sampling_rate, template.sampling_rate):
+            print(
+                f"Resampling template to {target.sampling_rate}. "
+                "Consider providing a template with the same sampling rate as the target."
+            )
+            template = template.resample(target.sampling_rate, order=3)

     template_mask = load_and_validate_mask(
         mask_target=template, mask_path=args.template_mask
@@ -863,31 +965,46 @@ def main():
     if args.memory is None:
         args.memory = int(args.memory_scaling * available_memory)

-    target_padding = np.zeros_like(template.shape)
-    if args.pad_target_edges:
-        target_padding = template.shape
+    callback_class = MaxScoreOverRotations
+    if args.peak_calling:
+        callback_class = PeakCallerMaximumFilter
+
+    matching_data = MatchingData(
+        target=target,
+        template=template.data,
+        target_mask=target_mask,
+        template_mask=template_mask,
+        invert_target=args.invert_target_contrast,
+        rotations=parse_rotation_logic(args=args, ndim=template.data.ndim),
+    )
+
+    template_filter, target_filter = setup_filter(args, template, target)
+    matching_data.template_filter = template_filter
+    matching_data.target_filter = target_filter

-    template_box = template.shape
+    template_box = matching_data._output_template_shape
     if not args.pad_fourier:
         template_box = np.ones(len(template_box), dtype=int)

-    callback_class = MaxScoreOverRotations
-    if args.peak_calling:
-        callback_class = PeakCallerMaximumFilter
+    target_padding = np.zeros(
+        (backend.size(matching_data._output_template_shape)), dtype=int
+    )
+    if args.pad_target_edges:
+        target_padding = matching_data._output_template_shape

     splits, schedule = compute_parallelization_schedule(
         shape1=target.shape,
-        shape2=template_box,
-        shape1_padding=target_padding,
+        shape2=tuple(int(x) for x in template_box),
+        shape1_padding=tuple(int(x) for x in target_padding),
         max_cores=args.cores,
         max_ram=args.memory,
         split_only_outer=args.use_gpu,
         matching_method=args.score,
         analyzer_method=callback_class.__name__,
         backend=backend._backend_name,
-        float_nbytes=backend.datatype_bytes(backend._default_dtype),
+        float_nbytes=backend.datatype_bytes(backend._float_dtype),
         complex_nbytes=backend.datatype_bytes(backend._complex_dtype),
-        integer_nbytes=backend.datatype_bytes(backend._default_dtype_int),
+        integer_nbytes=backend.datatype_bytes(backend._int_dtype),
     )

     if splits is None:
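The reordering above is deliberate: MatchingData is now assembled (masks, rotations, filters) before the parallelization schedule is computed, because template_box and target_padding are derived from matching_data._output_template_shape rather than from template.shape directly. A condensed sketch of that dependency, using only names from the hunk:

    # Both shapes now come from the assembled MatchingData object
    template_box = matching_data._output_template_shape
    target_padding = (
        matching_data._output_template_shape
        if args.pad_target_edges
        else np.zeros(len(template_box), dtype=int)
    )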
@@ -898,20 +1015,6 @@ def main():
         exit(-1)

     matching_setup, matching_score = MATCHING_EXHAUSTIVE_REGISTER[args.score]
-    matching_data = MatchingData(target=target, template=template.data)
-    matching_data.rotations = parse_rotation_logic(args=args, ndim=target.data.ndim)
-
-    template_filter, target_filter = setup_filter(args, template, target)
-    matching_data.template_filter = template_filter
-    matching_data.target_filter = target_filter
-
-    matching_data.template_filter = template_filter
-    matching_data._invert_target = args.invert_target_contrast
-    if target_mask is not None:
-        matching_data.target_mask = target_mask
-    if template_mask is not None:
-        matching_data.template_mask = template_mask.data
-
     n_splits = np.prod(list(splits.values()))
     target_split = ", ".join(
         [":".join([str(x) for x in axis]) for axis in splits.items()]
@@ -945,13 +1048,23 @@ def main():
         "Lowpass": args.lowpass,
         "Highpass": args.highpass,
         "Smooth Pass": args.no_pass_smooth,
-        "Pass Format" : args.pass_format,
+        "Pass Format": args.pass_format,
         "Spectral Whitening": args.whiten_spectrum,
         "Wedge Axes": args.wedge_axes,
         "Tilt Angles": args.tilt_angles,
         "Tilt Weighting": args.tilt_weighting,
-        "CTF": args.ctf_file,
+        "Reconstruction Filter": args.reconstruction_filter,
     }
+    if args.ctf_file is not None or args.defocus is not None:
+        filter_args["CTF File"] = args.ctf_file
+        filter_args["Defocus"] = args.defocus
+        filter_args["Phase Shift"] = args.phase_shift
+        filter_args["No Flip Phase"] = args.no_flip_phase
+        filter_args["Acceleration Voltage"] = args.acceleration_voltage
+        filter_args["Spherical Aberration"] = args.spherical_aberration
+        filter_args["Amplitude Contrast"] = args.amplitude_contrast
+        filter_args["Correct Defocus"] = args.correct_defocus_gradient
+
     filter_args = {k: v for k, v in filter_args.items() if v is not None}
     if len(filter_args):
         print_block(
@@ -1000,15 +1113,16 @@ def main():
         candidates[0] *= target_mask.data
     with warnings.catch_warnings():
         warnings.simplefilter("ignore", category=UserWarning)
+        nbytes = backend.datatype_bytes(backend._float_dtype)
+        dtype = np.float32 if nbytes == 4 else np.float16
+        rot_dim = matching_data.rotations.shape[1]
         candidates[3] = {
             x: euler_from_rotationmatrix(
-                np.frombuffer(i, dtype=matching_data.rotations.dtype).reshape(
-                    candidates[0].ndim, candidates[0].ndim
-                )
+                np.frombuffer(i, dtype=dtype).reshape(rot_dim, rot_dim)
             )
             for i, x in candidates[3].items()
         }
-    candidates.append((target.origin, template.origin, target.sampling_rate, args))
+    candidates.append((target.origin, template.origin, template.sampling_rate, args))
     write_pickle(data=candidates, filename=args.output)

    runtime = time() - start
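Context for the decode change above: analyzers key scores by the raw bytes of each rotation matrix, so reading the pickle back requires the correct dtype and matrix dimension. These are now derived from the backend float width and the rotation set rather than from the score volume's ndim. A self-contained NumPy illustration of the round trip:

    import numpy as np

    rotation = np.eye(3, dtype=np.float32)
    key = rotation.tobytes()                       # bytes used as dict key
    decoded = np.frombuffer(key, dtype=np.float32).reshape(3, 3)
    assert np.array_equal(rotation, decoded)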
--- a/scripts/postprocess.py
+++ b/scripts/postprocess.py
@@ -9,7 +9,7 @@ import argparse
 from sys import exit
 from os import getcwd
 from os.path import join, abspath
-from typing import List
+from typing import List, Tuple
 from os.path import splitext

 import numpy as np
@@ -29,6 +29,7 @@ from tme.matching_utils import (
     euler_to_rotationmatrix,
     euler_from_rotationmatrix,
 )
+from tme.matching_optimization import create_score_object, optimize_match

 PEAK_CALLERS = {
     "PeakCallerSort": PeakCallerSort,
@@ -181,6 +182,13 @@
         required=False,
         help="Number of accepted false-positives picks to determine minimum score.",
     )
+    additional_group.add_argument(
+        "--local_optimization",
+        action="store_true",
+        required=False,
+        help="[Experimental] Perform local optimization of candidates. Useful when the "
+        "number of identified candidats is small (< 10).",
+    )

     args = parser.parse_args()
@@ -195,28 +203,34 @@

     if args.minimum_score is not None or args.n_false_positives is not None:
         args.number_of_peaks = np.iinfo(np.int64).max
-    else:
+    elif args.number_of_peaks is None:
         args.number_of_peaks = 1000

     return args


-def load_template(filepath: str, sampling_rate: NDArray, center: bool = True):
+def load_template(
+    filepath: str,
+    sampling_rate: NDArray,
+    centering: bool = True,
+    target_shape: Tuple[int] = None,
+):
     try:
         template = Density.from_file(filepath)
-        center_of_mass = template.center_of_mass(template.data)
+        center = np.divide(np.subtract(template.shape, 1), 2)
         template_is_density = True
-    except ValueError:
+    except Exception:
         template = Structure.from_file(filepath)
-        center_of_mass = template.center_of_mass()[::-1]
+        center = template.center_of_mass()[::-1]
         template = Density.from_structure(template, sampling_rate=sampling_rate)
         template_is_density = False

-    translation = np.zeros_like(center_of_mass)
-    if center:
+    translation = np.zeros_like(center)
+    if centering and template_is_density:
         template, translation = template.centered(0)
+        center = np.divide(np.subtract(template.shape, 1), 2)

-    return template, center_of_mass, translation, template_is_density
+    return template, center, translation, template_is_density


 def merge_outputs(data, filepaths: List[str], args):
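Behavioral note on the load_template rewrite: for density inputs the returned center is now the geometric center (shape - 1) / 2, recomputed after re-centering, rather than the center of mass, and re-centering is only applied to densities. A usage sketch under the new signature (path and sampling rate are placeholders):

    template, center, translation, is_density = load_template(
        filepath="template.mrc",
        sampling_rate=np.array([4.0, 4.0, 4.0]),
        centering=True,
    )
    # For a density input, center == (np.array(template.shape) - 1) / 2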
@@ -226,7 +240,7 @@ def merge_outputs(data, filepaths: List[str], args):
     if data[0].ndim != data[2].ndim:
         return data, 1

-    from tme.matching_exhaustive import _normalize_under_mask
+    from tme.matching_exhaustive import normalize_under_mask

     def _norm_scores(data, args):
         target_origin, _, sampling_rate, cli_args = data[-1]
@@ -235,7 +249,7 @@
         ret = load_template(
             filepath=cli_args.template,
             sampling_rate=sampling_rate,
-            center=not cli_args.no_centering,
+            centering=not cli_args.no_centering,
         )
         template, center_of_mass, translation, template_is_density = ret

@@ -256,7 +270,7 @@
             mask.shape, np.multiply(args.min_boundary_distance, 2)
         ).astype(int)
         mask[cropped_shape] = 0
-        _normalize_under_mask(template=data[0], mask=mask, mask_intensity=mask.sum())
+        normalize_under_mask(template=data[0], mask=mask, mask_intensity=mask.sum())
         return data[0]

     entities = np.zeros_like(data[0])
@@ -280,7 +294,7 @@ def main():
     ret = load_template(
         filepath=cli_args.template,
         sampling_rate=sampling_rate,
-        center=not cli_args.no_centering,
+        centering=not cli_args.no_centering,
     )
     template, center_of_mass, translation, template_is_density = ret

@@ -310,7 +324,9 @@
     max_shape = np.max(template.shape)
     args.min_boundary_distance = np.ceil(np.divide(max_shape, 2))

-    # data, entities = merge_outputs(data=data, filepaths=args.input_file[1:], args=args)
+    entities = None
+    if len(args.input_file) > 1:
+        data, entities = merge_outputs(data=data, filepaths=args.input_file, args=args)

     orientations = args.orientations
     if orientations is None:
@@ -346,31 +362,33 @@
         minimum_score = max(minimum_score, 0)
         args.minimum_score = minimum_score

-    peak_caller = PEAK_CALLERS[args.peak_caller](
-        number_of_peaks=args.number_of_peaks,
-        min_distance=args.min_distance,
-        min_boundary_distance=args.min_boundary_distance,
-    )
-    if args.minimum_score is not None:
-        args.number_of_peaks = np.inf
+    args.batch_dims = None
+    if hasattr(cli_args, "target_batch"):
+        args.batch_dims = cli_args.target_batch
+
+    peak_caller_kwargs = {
+        "number_of_peaks": args.number_of_peaks,
+        "min_distance": args.min_distance,
+        "min_boundary_distance": args.min_boundary_distance,
+        "batch_dims": args.batch_dims,
+    }

+    peak_caller = PEAK_CALLERS[args.peak_caller](**peak_caller_kwargs)
     peak_caller(
         scores,
-        rotation_matrix=np.eye(3),
+        rotation_matrix=np.eye(template.data.ndim),
         mask=template.data,
         rotation_mapping=rotation_mapping,
         rotation_array=rotation_array,
         minimum_score=args.minimum_score,
     )
     candidates = peak_caller.merge(
-        candidates=[tuple(peak_caller)],
-        number_of_peaks=args.number_of_peaks,
-        min_distance=args.min_distance,
-        min_boundary_distance=args.min_boundary_distance,
+        candidates=[tuple(peak_caller)], **peak_caller_kwargs
     )
     if len(candidates) == 0:
-        print("Found no peaks. Consider changing peak calling parameters.")
-        exit(-1)
+        candidates = [[], [], [], []]
+        print("Found no peaks, consider changing peak calling parameters.")
+        exit(0)

     for translation, _, score, detail in zip(*candidates):
         rotations.append(rotation_mapping[rotation_array[tuple(translation)]])
@@ -381,8 +399,13 @@
         for i in range(translation.shape[0]):
             rotations.append(euler_from_rotationmatrix(rotation[i]))

-    rotations = np.vstack(rotations).astype(float)
+    if len(rotations):
+        rotations = np.vstack(rotations).astype(float)
     translations, scores, details = candidates[0], candidates[2], candidates[3]
+
+    if entities is not None:
+        details = entities[tuple(translations.T)]
+
     orientations = Orientations(
         translations=translations,
         rotations=rotations,
@@ -390,14 +413,55 @@
         details=details,
     )

-    if args.minimum_score is not None:
+    if args.minimum_score is not None and len(orientations.scores):
         keep = orientations.scores >= args.minimum_score
         orientations = orientations[keep]

-    if args.maximum_score is not None:
+    if args.maximum_score is not None and len(orientations.scores):
         keep = orientations.scores <= args.maximum_score
         orientations = orientations[keep]

+    if args.peak_oversampling > 1:
+        peak_caller = peak_caller = PEAK_CALLERS[args.peak_caller]()
+        if data[0].ndim != data[2].ndim:
+            print(
+                "Input pickle does not contain template matching scores."
+                " Cannot oversample peaks."
+            )
+            exit(-1)
+        orientations.translations = peak_caller.oversample_peaks(
+            score_space=data[0],
+            peak_positions=orientations.translations,
+            oversampling_factor=args.peak_oversampling,
+        )
+
+    if args.local_optimization:
+        target = Density.from_file(cli_args.target)
+        orientations.translations = orientations.translations.astype(np.float32)
+        orientations.rotations = orientations.rotations.astype(np.float32)
+        for index, (translation, angles, *_) in enumerate(orientations):
+            score_object = create_score_object(
+                score="FLC",
+                target=target.data.copy(),
+                template=template.data.copy(),
+                template_mask=template_mask.data.copy(),
+            )
+
+            center = np.divide(template.shape, 2)
+            init_translation = np.subtract(translation, center)
+            bounds_translation = tuple((x - 5, x + 5) for x in init_translation)
+
+            translation, rotation_matrix, score = optimize_match(
+                score_object=score_object,
+                optimization_method="basinhopping",
+                bounds_translation=bounds_translation,
+                maxiter=3,
+                x0=[*init_translation, *angles],
+            )
+            orientations.translations[index] = np.add(translation, center)
+            orientations.rotations[index] = angles
+            orientations.scores[index] = score * -1
+
     if args.output_format == "orientations":
         orientations.to_file(filename=f"{args.output_prefix}.tsv", file_format="text")
         exit(0)
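Condensed from the --local_optimization loop above: each candidate seeds a bounded basin-hopping refinement of an FLC score around its reported position (±5 voxels), and the negated score on assignment suggests optimize_match minimizes internally. The per-candidate core, with names as in the hunk:

    score_object = create_score_object(
        score="FLC",
        target=target.data.copy(),
        template=template.data.copy(),
        template_mask=template_mask.data.copy(),
    )
    translation, rotation_matrix, score = optimize_match(
        score_object=score_object,
        optimization_method="basinhopping",
        bounds_translation=tuple((x - 5, x + 5) for x in init_translation),
        maxiter=3,
        x0=[*init_translation, *angles],  # translation offsets, then Euler angles
    )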
@@ -515,7 +579,6 @@ def main():
         )
         rotation_matrix = rotation.inv().as_matrix()

-        # rotation_matrix = euler_to_rotationmatrix(orientations.rotations[index])
         subset = Density(target.data[obs_slices[index]])
         subset = subset.rigid_transform(rotation_matrix=rotation_matrix, order=1)

@@ -526,35 +589,30 @@ def main():
         ret.to_file(f"{args.output_prefix}_average.mrc")
         exit(0)

-    if args.peak_oversampling > 1:
-        peak_caller = peak_caller = PEAK_CALLERS[args.peak_caller]()
-        if data[0].ndim != data[2].ndim:
-            print(
-                "Input pickle does not contain template matching scores."
-                " Cannot oversample peaks."
-            )
-            exit(-1)
-        orientations.translations = peak_caller.oversample_peaks(
-            score_space=data[0],
-            translations=orientations.translations,
-            oversampling_factor=args.oversampling_factor,
-        )
+    template, center, *_ = load_template(
+        filepath=cli_args.template,
+        sampling_rate=sampling_rate,
+        centering=not cli_args.no_centering,
+        target_shape=target.shape,
+    )

     for index, (translation, angles, *_) in enumerate(orientations):
         rotation_matrix = euler_to_rotationmatrix(angles)
         if template_is_density:
-            translation = np.subtract(translation, center_of_mass)
+            translation = np.subtract(translation, center)
             transformed_template = template.rigid_transform(
                 rotation_matrix=rotation_matrix
             )
-            new_origin = np.add(target_origin / sampling_rate, translation)
-            transformed_template.origin = np.multiply(new_origin, sampling_rate)
+            transformed_template.origin = np.add(
+                target_origin, np.multiply(translation, sampling_rate)
+            )
+
         else:
             template = Structure.from_file(cli_args.template)
             new_center_of_mass = np.add(
                 np.multiply(translation, sampling_rate), target_origin
             )
-            translation = np.subtract(new_center_of_mass, center_of_mass)
+            translation = np.subtract(new_center_of_mass, center)
             transformed_template = template.rigid_transform(
                 translation=translation[::-1],
                 rotation_matrix=rotation_matrix[::-1, ::-1],
--- a/scripts/preprocessor_gui.py
+++ b/scripts/preprocessor_gui.py
@@ -789,7 +789,10 @@ class AlignmentWidget(widgets.Container):
         active_layer = self.viewer.layers.selection.active
         if active_layer is None:
             return ()
-        return [i for i in range(active_layer.data.ndim)]
+        try:
+            return [i for i in range(active_layer.data.ndim)]
+        except Exception:
+            return ()

     def _update_align_axis(self, *args):
         self.align_axis.choices = self._get_active_layer_dims()