pytme 0.2.9__cp311-cp311-macosx_15_0_arm64.whl → 0.3.0__cp311-cp311-macosx_15_0_arm64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (75)
  1. pytme-0.3.0.data/scripts/estimate_memory_usage.py +76 -0
  2. pytme-0.3.0.data/scripts/match_template.py +1106 -0
  3. {pytme-0.2.9.data → pytme-0.3.0.data}/scripts/postprocess.py +320 -190
  4. {pytme-0.2.9.data → pytme-0.3.0.data}/scripts/preprocess.py +21 -31
  5. {pytme-0.2.9.data → pytme-0.3.0.data}/scripts/preprocessor_gui.py +85 -19
  6. pytme-0.3.0.data/scripts/pytme_runner.py +771 -0
  7. {pytme-0.2.9.dist-info → pytme-0.3.0.dist-info}/METADATA +22 -20
  8. pytme-0.3.0.dist-info/RECORD +126 -0
  9. {pytme-0.2.9.dist-info → pytme-0.3.0.dist-info}/entry_points.txt +2 -1
  10. pytme-0.3.0.dist-info/licenses/LICENSE +339 -0
  11. scripts/estimate_memory_usage.py +76 -0
  12. scripts/eval.py +93 -0
  13. scripts/extract_candidates.py +224 -0
  14. scripts/match_template.py +349 -378
  15. pytme-0.2.9.data/scripts/match_template.py → scripts/match_template_filters.py +213 -148
  16. scripts/postprocess.py +320 -190
  17. scripts/preprocess.py +21 -31
  18. scripts/preprocessor_gui.py +85 -19
  19. scripts/pytme_runner.py +771 -0
  20. scripts/refine_matches.py +625 -0
  21. tests/preprocessing/test_frequency_filters.py +28 -14
  22. tests/test_analyzer.py +41 -36
  23. tests/test_backends.py +1 -0
  24. tests/test_matching_cli.py +109 -53
  25. tests/test_matching_data.py +5 -5
  26. tests/test_matching_exhaustive.py +1 -2
  27. tests/test_matching_optimization.py +4 -9
  28. tests/test_matching_utils.py +1 -1
  29. tests/test_orientations.py +0 -1
  30. tme/__version__.py +1 -1
  31. tme/analyzer/__init__.py +2 -0
  32. tme/analyzer/_utils.py +26 -21
  33. tme/analyzer/aggregation.py +396 -222
  34. tme/analyzer/base.py +127 -0
  35. tme/analyzer/peaks.py +189 -201
  36. tme/analyzer/proxy.py +123 -0
  37. tme/backends/__init__.py +4 -3
  38. tme/backends/_cupy_utils.py +25 -24
  39. tme/backends/_jax_utils.py +20 -18
  40. tme/backends/cupy_backend.py +13 -26
  41. tme/backends/jax_backend.py +24 -23
  42. tme/backends/matching_backend.py +4 -3
  43. tme/backends/mlx_backend.py +4 -3
  44. tme/backends/npfftw_backend.py +34 -30
  45. tme/backends/pytorch_backend.py +18 -4
  46. tme/cli.py +126 -0
  47. tme/density.py +9 -7
  48. tme/extensions.cpython-311-darwin.so +0 -0
  49. tme/filters/__init__.py +3 -3
  50. tme/filters/_utils.py +36 -10
  51. tme/filters/bandpass.py +229 -188
  52. tme/filters/compose.py +5 -4
  53. tme/filters/ctf.py +516 -254
  54. tme/filters/reconstruction.py +91 -32
  55. tme/filters/wedge.py +196 -135
  56. tme/filters/whitening.py +37 -42
  57. tme/matching_data.py +28 -39
  58. tme/matching_exhaustive.py +31 -27
  59. tme/matching_optimization.py +5 -4
  60. tme/matching_scores.py +25 -15
  61. tme/matching_utils.py +158 -28
  62. tme/memory.py +4 -3
  63. tme/orientations.py +22 -9
  64. tme/parser.py +114 -33
  65. tme/preprocessor.py +6 -5
  66. tme/rotations.py +10 -7
  67. tme/structure.py +4 -3
  68. pytme-0.2.9.data/scripts/estimate_ram_usage.py +0 -97
  69. pytme-0.2.9.dist-info/RECORD +0 -119
  70. pytme-0.2.9.dist-info/licenses/LICENSE +0 -153
  71. scripts/estimate_ram_usage.py +0 -97
  72. tests/data/Maps/.DS_Store +0 -0
  73. tests/data/Structures/.DS_Store +0 -0
  74. {pytme-0.2.9.dist-info → pytme-0.3.0.dist-info}/WHEEL +0 -0
  75. {pytme-0.2.9.dist-info → pytme-0.3.0.dist-info}/top_level.txt +0 -0
@@ -1,5 +1,5 @@
- #!python
- """ CLI for basic pyTME template matching functions.
+ #!python3
+ """ CLI interface for basic pyTME template matching functions.

      Copyright (c) 2023 European Molecular Biology Laboratory

@@ -12,34 +12,28 @@ from sys import exit
  from time import time
  from typing import Tuple
  from copy import deepcopy
- from os.path import exists
- from tempfile import gettempdir
+ from os.path import abspath, exists

  import numpy as np

- from tme.backends import backend as be
  from tme import Density, __version__
- from tme.matching_utils import scramble_phases, write_pickle
- from tme.matching_exhaustive import scan_subsets, MATCHING_EXHAUSTIVE_REGISTER
- from tme.rotations import (
-     get_cone_rotations,
+ from tme.matching_utils import (
      get_rotation_matrices,
+     get_rotations_around_vector,
+     compute_parallelization_schedule,
+     scramble_phases,
+     generate_tempfile_name,
+     write_pickle,
  )
+ from tme.matching_exhaustive import scan_subsets, MATCHING_EXHAUSTIVE_REGISTER
  from tme.matching_data import MatchingData
  from tme.analyzer import (
      MaxScoreOverRotations,
      PeakCallerMaximumFilter,
  )
- from tme.filters import (
-     CTF,
-     Wedge,
-     Compose,
-     BandPassFilter,
-     WedgeReconstructed,
-     ReconstructFromTilt,
-     LinearWhiteningFilter,
- )
-
+ from tme.backends import backend as be
+ from tme.preprocessing import Compose
+ from tme.scoring import flc_scoring2

  def get_func_fullname(func) -> str:
      """Returns the full name of the given function, including its module."""
@@ -50,8 +44,6 @@ def print_block(name: str, data: dict, label_width=20) -> None:
      """Prints a formatted block of information."""
      print(f"\n> {name}")
      for key, value in data.items():
-         if isinstance(value, np.ndarray):
-             value = value.shape
          formatted_value = str(value)
          print(f" - {key + ':':<{label_width}} {formatted_value}")

@@ -107,9 +99,7 @@ def load_and_validate_mask(mask_target: "Density", mask_path: str, **kwargs):
              f"Expected shape of {mask_path} was {mask_target.shape},"
              f" got f{mask.shape}"
          )
-     if not np.allclose(
-         np.round(mask.sampling_rate, 2), np.round(mask_target.sampling_rate, 2)
-     ):
+     if not np.allclose(mask.sampling_rate, mask_target.sampling_rate):
          raise ValueError(
              f"Expected sampling_rate of {mask_path} was {mask_target.sampling_rate}"
              f", got f{mask.sampling_rate}"
@@ -117,6 +107,50 @@ def load_and_validate_mask(mask_target: "Density", mask_path: str, **kwargs):
      return mask


+ def crop_data(data: Density, cutoff: float, data_mask: Density = None) -> bool:
+     """
+     Crop the provided data and mask to a smaller box based on a cutoff value.
+
+     Parameters
+     ----------
+     data : Density
+         The data that should be cropped.
+     cutoff : float
+         The threshold value to determine which parts of the data should be kept.
+     data_mask : Density, optional
+         A mask for the data that should be cropped.
+
+     Returns
+     -------
+     bool
+         Returns True if the data was adjusted (cropped), otherwise returns False.
+
+     Notes
+     -----
+     Cropping is performed in place.
+     """
+     if cutoff is None:
+         return False
+
+     box = data.trim_box(cutoff=cutoff)
+     box_mask = box
+     if data_mask is not None:
+         box_mask = data_mask.trim_box(cutoff=cutoff)
+     box = tuple(
+         slice(min(arr.start, mask.start), max(arr.stop, mask.stop))
+         for arr, mask in zip(box, box_mask)
+     )
+     if box == tuple(slice(0, x) for x in data.shape):
+         return False
+
+     data.adjust_box(box)
+
+     if data_mask:
+         data_mask.adjust_box(box)
+
+     return True
+
+
  def parse_rotation_logic(args, ndim):
      if args.angular_sampling is not None:
          rotations = get_rotation_matrices(
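Note on the `crop_data` hunk above: the function unions the per-axis trim boxes of data and mask before cropping both in place, so neither is cut below the other's extent. A minimal standalone sketch of that slice-union step (`union_boxes` is a hypothetical helper; `Density.trim_box` is assumed to return one slice per axis, as the code above implies):

```python
# Sketch of the slice-union step in crop_data above (no pytme required).
def union_boxes(box_a, box_b):
    """Smallest per-axis slice covering both input boxes."""
    return tuple(
        slice(min(a.start, b.start), max(a.stop, b.stop))
        for a, b in zip(box_a, box_b)
    )

data_box = (slice(10, 90), slice(20, 80))  # e.g. data.trim_box(cutoff=...)
mask_box = (slice(5, 70), slice(30, 95))   # e.g. data_mask.trim_box(cutoff=...)
assert union_boxes(data_box, mask_box) == (slice(5, 90), slice(20, 95))
```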
@@ -131,66 +165,34 @@ def parse_rotation_logic(args, ndim):
      if args.axis_sampling is None:
          args.axis_sampling = args.cone_sampling

-     rotations = get_cone_rotations(
+     rotations = get_rotations_around_vector(
          cone_angle=args.cone_angle,
          cone_sampling=args.cone_sampling,
          axis_angle=args.axis_angle,
          axis_sampling=args.axis_sampling,
          n_symmetry=args.axis_symmetry,
-         axis=[0 if i != args.cone_axis else 1 for i in range(ndim)],
-         reference=[0, 0, -1],
      )
      return rotations


- def compute_schedule(
-     args,
-     matching_data: MatchingData,
-     callback_class,
-     pad_edges: bool = False,
- ):
-     # User requested target padding
-     if args.pad_edges is True:
-         pad_edges = True
-
-     splits, schedule = matching_data.computation_schedule(
-         matching_method=args.score,
-         analyzer_method=callback_class.__name__,
-         use_gpu=args.use_gpu,
-         pad_fourier=False,
-         pad_target_edges=pad_edges,
-         available_memory=args.memory,
-         max_cores=args.cores,
+ # TODO: Think about whether wedge mask should also be added to target
+ # For now leave it at the cost of incorrect upper bound on the scores
+ def setup_filter(args, template: Density, target: Density) -> Tuple[Compose, Compose]:
+     from tme.preprocessing import LinearWhiteningFilter, BandPassFilter
+     from tme.preprocessing.tilt_series import (
+         Wedge,
+         WedgeReconstructed,
+         ReconstructFromTilt,
      )

-     if splits is None:
-         print(
-             "Found no suitable parallelization schedule. Consider increasing"
-             " available RAM or decreasing number of cores."
-         )
-         exit(-1)
-
-     n_splits = np.prod(list(splits.values()))
-     if pad_edges is False and len(matching_data._target_dim) == 0 and n_splits > 1:
-         args.pad_edges = True
-         return compute_schedule(args, matching_data, callback_class, True)
-     return splits, schedule
-
-
- def setup_filter(args, template: Density, target: Density) -> Tuple[Compose, Compose]:
-     needs_reconstruction = False
      template_filter, target_filter = [], []
      if args.tilt_angles is not None:
-         needs_reconstruction = args.tilt_weighting is not None
          try:
              wedge = Wedge.from_file(args.tilt_angles)
              wedge.weight_type = args.tilt_weighting
              if args.tilt_weighting in ("angle", None) and args.ctf_file is None:
                  wedge = WedgeReconstructed(
-                     angles=wedge.angles,
-                     weight_wedge=args.tilt_weighting == "angle",
-                     opening_axis=args.wedge_axes[0],
-                     tilt_axis=args.wedge_axes[1],
+                     angles=wedge.angles, weight_wedge=args.tilt_weighting == "angle"
                  )
          except FileNotFoundError:
              tilt_step, create_continuous_wedge = None, True
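The hunk above swaps `get_cone_rotations` for `get_rotations_around_vector` while keeping the same cone/axis sampling arguments. As a rough illustration of what cone sampling computes, here is a Fibonacci-lattice sketch of near-uniform directions within a half-angle of +z; `cone_directions` is a hypothetical helper and this is a common construction, not necessarily pytme's exact scheme (which additionally yields full rotation matrices):

```python
import numpy as np

def cone_directions(cone_angle_deg: float, n_points: int) -> np.ndarray:
    """Near-uniform unit vectors within cone_angle_deg of +z (Fibonacci lattice)."""
    cos_max = np.cos(np.radians(cone_angle_deg))
    i = np.arange(n_points)
    z = 1 - (1 - cos_max) * (i + 0.5) / n_points  # cos(theta), uniform in cap area
    phi = np.pi * (1 + np.sqrt(5)) * i            # golden-angle azimuths
    r = np.sqrt(1 - z * z)
    return np.stack([r * np.cos(phi), r * np.sin(phi), z], axis=1)

dirs = cone_directions(15.0, 200)
assert np.all(dirs[:, 2] >= np.cos(np.radians(15.0)) - 1e-9)
```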
@@ -229,29 +231,23 @@ def setup_filter(args, template: Density, target: Density) -> Tuple[Compose, Com
                      angles=tilt_angles,
                      weight_wedge=args.tilt_weighting == "angle",
                      create_continuous_wedge=create_continuous_wedge,
-                     reconstruction_filter=args.reconstruction_filter,
-                     opening_axis=args.wedge_axes[0],
-                     tilt_axis=args.wedge_axes[1],
-                 )
-                 wedge_target = WedgeReconstructed(
-                     angles=(np.abs(np.min(tilt_angles)), np.abs(np.max(tilt_angles))),
-                     weight_wedge=False,
-                     create_continuous_wedge=True,
-                     opening_axis=args.wedge_axes[0],
-                     tilt_axis=args.wedge_axes[1],
                  )
-                 target_filter.append(wedge_target)

+         wedge.opening_axis = args.wedge_axes[0]
+         wedge.tilt_axis = args.wedge_axes[1]
          wedge.sampling_rate = template.sampling_rate
          template_filter.append(wedge)
          if not isinstance(wedge, WedgeReconstructed):
-             reconstruction_filter = ReconstructFromTilt(
-                 reconstruction_filter=args.reconstruction_filter,
-                 interpolation_order=args.reconstruction_interpolation_order,
+             template_filter.append(
+                 ReconstructFromTilt(
+                     reconstruction_filter=args.reconstruction_filter,
+                     interpolation_order=args.reconstruction_interpolation_order,
+                 )
              )
-             template_filter.append(reconstruction_filter)

      if args.ctf_file is not None or args.defocus is not None:
+         from tme.preprocessing.tilt_series import CTF
+
          needs_reconstruction = True
          if args.ctf_file is not None:
              ctf = CTF.from_file(args.ctf_file)
@@ -263,7 +259,6 @@ def setup_filter(args, template: Density, target: Density) -> Tuple[Compose, Com
                  "per micrograph."
              )
              ctf.angles = wedge.angles
-             ctf.no_reconstruction = False
              ctf.opening_axis, ctf.tilt_axis = args.wedge_axes
          else:
              needs_reconstruction = False
@@ -322,13 +317,18 @@ def setup_filter(args, template: Density, target: Density) -> Tuple[Compose, Com
              sampling_rate=template.sampling_rate,
          )
          template_filter.append(bandpass)
-         target_filter.append(bandpass)
+
+         if not args.no_filter_target:
+             target_filter.append(bandpass)

      if args.whiten_spectrum:
          whitening_filter = LinearWhiteningFilter()
          template_filter.append(whitening_filter)
          target_filter.append(whitening_filter)

+     needs_reconstruction = any(
+         [isinstance(t, ReconstructFromTilt) for t in template_filter]
+     )
      if needs_reconstruction and args.reconstruction_filter is None:
          warnings.warn(
              "Consider using a --reconstruction_filter such as 'ramp' to avoid artifacts."
@@ -336,8 +336,6 @@ def setup_filter(args, template: Density, target: Density) -> Tuple[Compose, Com

      template_filter = Compose(template_filter) if len(template_filter) else None
      target_filter = Compose(target_filter) if len(target_filter) else None
-     if args.no_filter_target:
-         target_filter = None

      return template_filter, target_filter

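For context on the return value above: `template_filter` and `target_filter` are lists of filter objects wrapped in `Compose`, or `None` when empty. A minimal sketch of sequential composition, assuming `Compose` applies its members in order (`ComposeSketch` is illustrative; pytme's actual `Compose` call signature may differ):

```python
import numpy as np

class ComposeSketch:
    """Apply a sequence of callables in order, mirroring the Compose usage above."""

    def __init__(self, transforms):
        self.transforms = list(transforms)

    def __call__(self, data):
        for transform in self.transforms:
            data = transform(data)
        return data

pipeline = ComposeSketch([np.abs, np.sqrt])
print(pipeline(np.array([-4.0, -9.0])))  # [2. 3.]
```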
@@ -345,7 +343,7 @@ def setup_filter(args, template: Density, target: Density) -> Tuple[Compose, Com
  def parse_args():
      parser = argparse.ArgumentParser(
          description="Perform template matching.",
-         formatter_class=argparse.ArgumentDefaultsHelpFormatter,
+         formatter_class=argparse.ArgumentDefaultsHelpFormatter
      )

      io_group = parser.add_argument_group("Input / Output")
@@ -395,7 +393,8 @@ def parse_args():
          dest="invert_target_contrast",
          action="store_true",
          default=False,
-         help="Invert the target's contrast for cases where templates to-be-matched have "
+         help="Invert the target's contrast and rescale linearly between zero and one. "
+         "This option is intended for targets where templates to-be-matched have "
          "negative values, e.g. tomograms.",
      )
      io_group.add_argument(
@@ -435,19 +434,6 @@ def parse_args():
          help="Half-angle of the cone to be sampled in degrees. Allows to sample a "
          "narrow interval around a known orientation, e.g. for surface oversampling.",
      )
-     angular_group.add_argument(
-         "--cone_axis",
-         dest="cone_axis",
-         type=check_positive,
-         default=2,
-         help="Principal axis to build cone around.",
-     )
-     angular_group.add_argument(
-         "--invert_cone",
-         dest="invert_cone",
-         action="store_true",
-         help="Invert cone handedness.",
-     )
      angular_group.add_argument(
          "--cone_sampling",
          dest="cone_sampling",
@@ -529,7 +515,7 @@ def parse_args():
          required=False,
          type=float,
          default=0.85,
-         help="Fraction of available memory to be used. Ignored if --ram is set.",
+         help="Fraction of available memory to be used. Ignored if --ram is set."
      )
      computation_group.add_argument(
          "--temp_directory",
@@ -544,6 +530,7 @@ def parse_args():
          choices=be.available_backends(),
          help="[Expert] Overwrite default computation backend.",
      )
+
      filter_group = parser.add_argument_group("Filters")
      filter_group.add_argument(
          "--lowpass",
@@ -590,8 +577,8 @@ def parse_args():
          type=str,
          required=False,
          default=None,
-         help="Indices of wedge opening and tilt axis, e.g. '2,0' for a wedge open "
-         "in z and tilted over the x-axis.",
+         help="Indices of wedge opening and tilt axis, e.g. 0,2 for a wedge that is open "
+         "in z-direction and tilted over the x axis.",
      )
      filter_group.add_argument(
          "--tilt_angles",
@@ -698,7 +685,7 @@ def parse_args():
          dest="no_flip_phase",
          action="store_false",
          required=False,
-         help="Do not perform phase-flipping CTF correction.",
+         help="Perform phase-flipping CTF correction.",
      )
      ctf_group.add_argument(
          "--correct_defocus_gradient",
@@ -710,6 +697,22 @@ def parse_args():
      )

      performance_group = parser.add_argument_group("Performance")
+     performance_group.add_argument(
+         "--cutoff_target",
+         dest="cutoff_target",
+         type=float,
+         required=False,
+         default=None,
+         help="Target contour level (used for cropping).",
+     )
+     performance_group.add_argument(
+         "--cutoff_template",
+         dest="cutoff_template",
+         type=float,
+         required=False,
+         default=None,
+         help="Template contour level (used for cropping).",
+     )
      performance_group.add_argument(
          "--no_centering",
          dest="no_centering",
@@ -717,20 +720,30 @@ def parse_args():
          help="Assumes the template is already centered and omits centering.",
      )
      performance_group.add_argument(
-         "--pad_edges",
-         dest="pad_edges",
+         "--no_edge_padding",
+         dest="no_edge_padding",
+         action="store_true",
+         default=False,
+         help="Whether to not pad the edges of the target. Can be set if the target"
+         " has a well defined bounding box, e.g. a masked reconstruction.",
+     )
+     performance_group.add_argument(
+         "--no_fourier_padding",
+         dest="no_fourier_padding",
          action="store_true",
          default=False,
-         help="Whether to pad the edges of the target. Useful if the target does not "
-         "a well-defined bounding box. Defaults to True if splitting is required.",
+         help="Whether input arrays should not be zero-padded to full convolution shape "
+         "for numerical stability. When working with very large targets, e.g. tomograms, "
+         "it is safe to use this flag and benefit from the performance gain.",
      )
      performance_group.add_argument(
-         "--pad_filter",
-         dest="pad_filter",
+         "--no_filter_padding",
+         dest="no_filter_padding",
          action="store_true",
          default=False,
-         help="Pads the filter to the shape of the target. Particularly useful for fast "
-         "oscilating filters to avoid aliasing effects.",
+         help="Omits padding of optional template filters. Particularly effective when "
+         "the target is much larger than the template. However, for fast osciliating "
+         "filters setting this flag can introduce aliasing effects.",
      )
      performance_group.add_argument(
          "--interpolation_order",
@@ -777,7 +790,7 @@ def parse_args():
          dest="number_of_peaks",
          action="store_true",
          default=1000,
-         help="Number of peaks to call, 1000 by default.",
+         help="Number of peaks to call, 1000 by default..",
      )
      args = parser.parse_args()
      args.version = __version__
@@ -786,9 +799,16 @@ def parse_args():
          args.interpolation_order = None

      if args.temp_directory is None:
-         args.temp_directory = gettempdir()
+         default = abspath(".")
+         if os.environ.get("TMPDIR", None) is not None:
+             default = os.environ.get("TMPDIR")
+         args.temp_directory = default

      os.environ["TMPDIR"] = args.temp_directory
+
+     args.pad_target_edges = not args.no_edge_padding
+     args.pad_fourier = not args.no_fourier_padding
+
      if args.score not in MATCHING_EXHAUSTIVE_REGISTER:
          raise ValueError(
              f"score has to be one of {', '.join(MATCHING_EXHAUSTIVE_REGISTER.keys())}"
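The hunk above resolves the temporary directory (TMPDIR if set, else the working directory) and converts the negated CLI flags into positive `pad_*` booleans used downstream. A self-contained sketch of both steps, mirroring the hunk; the two-flag parser here is illustrative, not the script's full argument set:

```python
import os
import argparse
from os.path import abspath

parser = argparse.ArgumentParser()
parser.add_argument("--no_edge_padding", action="store_true", default=False)
parser.add_argument("--no_fourier_padding", action="store_true", default=False)
args = parser.parse_args(["--no_fourier_padding"])

# TMPDIR wins over the current working directory, as in the hunk above.
temp_directory = abspath(".")
if os.environ.get("TMPDIR", None) is not None:
    temp_directory = os.environ.get("TMPDIR")

# Invert the negated flags once; downstream code reads the positive form.
args.pad_target_edges = not args.no_edge_padding
args.pad_fourier = not args.no_fourier_padding
assert args.pad_target_edges and not args.pad_fourier
```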
@@ -837,15 +857,15 @@ def main():
      try:
          template = Density.from_file(args.template)
      except Exception:
+         drop = target.metadata.get("batch_dimension", ())
+         keep = [i not in drop for i in range(target.data.ndim)]
          template = Density.from_structure(
              filename_or_structure=args.template,
-             sampling_rate=target.sampling_rate,
+             sampling_rate=target.sampling_rate[keep],
          )

      if target.sampling_rate.size == template.sampling_rate.size:
-         if not np.allclose(
-             np.round(target.sampling_rate, 2), np.round(template.sampling_rate, 2)
-         ):
+         if not np.allclose(target.sampling_rate, template.sampling_rate):
              print(
                  f"Resampling template to {target.sampling_rate}. "
                  "Consider providing a template with the same sampling rate as the target."
@@ -860,6 +880,9 @@ def main():
      )

      initial_shape = target.shape
+     is_cropped = crop_data(
+         data=target, data_mask=target_mask, cutoff=args.cutoff_target
+     )
      print_block(
          name="Target",
          data={
@@ -868,6 +891,13 @@ def main():
              "Final Shape": target.shape,
          },
      )
+     if is_cropped:
+         args.target = generate_tempfile_name(suffix=".mrc")
+         target.to_file(args.target)
+
+         if target_mask:
+             args.target_mask = generate_tempfile_name(suffix=".mrc")
+             target_mask.to_file(args.target_mask)

      if target_mask:
          print_block(
@@ -880,6 +910,8 @@ def main():
          )

      initial_shape = template.shape
+     _ = crop_data(data=template, data_mask=template_mask, cutoff=args.cutoff_template)
+
      translation = np.zeros(len(template.shape), dtype=np.float32)
      if not args.no_centering:
          template, translation = template.centered(0)
@@ -927,7 +959,7 @@ def main():

      if args.scramble_phases:
          template.data = scramble_phases(
-             template.data, noise_proportion=1.0, normalize_power=False
+             template.data, noise_proportion=1.0, normalize_power=True
          )

      # Determine suitable backend for the selected operation
@@ -935,8 +967,10 @@ def main():
      if args.backend is not None:
          req_backend = args.backend
          if req_backend not in available_backends:
-             raise ValueError("Requested backend is not available.")
-         available_backends = [req_backend]
+             raise ValueError(
+                 "Requested backend is not available."
+             )
+         available_backends = [req_backend,]

      be_selection = ("numpyfftw", "pytorch", "jax", "mlx")
      if args.use_gpu:
@@ -951,21 +985,23 @@ def main():
          available_backends.remove("jax")
      if args.use_gpu and "pytorch" in available_backends:
          available_backends = ("pytorch",)
-
-     # dim_match = len(template.shape) == len(target.shape) <= 3
-     # if dim_match and args.use_gpu and "jax" in available_backends:
-     #     args.interpolation_order = 1
-     #     available_backends = ["jax"]
+         if args.interpolation_order == 3:
+             raise NotImplementedError(
+                 "Pytorch does not support --interpolation_order 3, 1 is supported."
+             )
+     ndim = len(template.shape)
+     if len(target.shape) == ndim and ndim <= 3 and args.use_gpu:
+         available_backends = ["jax", ]

      backend_preference = ("numpyfftw", "pytorch", "jax", "mlx")
      if args.use_gpu:
-         backend_preference = ("cupy", "pytorch", "jax")
+         backend_preference = ("cupy", "jax", "pytorch")
      for pref in backend_preference:
          if pref not in available_backends:
              continue
          be.change_backend(pref)
          if pref == "pytorch":
-             be.change_backend(pref, device="cuda" if args.use_gpu else "cpu")
+             be.change_backend(pref, device = "cuda" if args.use_gpu else "cpu")

      if args.use_mixed_precision:
          be.change_backend(
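The selection loop above settles on the first backend that is both preferred and available, with a GPU-specific preference order. The same first-match pattern in isolation (`pick_backend` is a hypothetical helper standing in for the loop around `be.change_backend`):

```python
def pick_backend(available, use_gpu):
    """Return the first available backend in preference order; first match wins."""
    preference = ("cupy", "jax", "pytorch") if use_gpu else (
        "numpyfftw", "pytorch", "jax", "mlx"
    )
    for name in preference:
        if name in available:
            return name
    raise RuntimeError("No suitable backend available.")

assert pick_backend({"numpyfftw", "pytorch"}, use_gpu=False) == "numpyfftw"
assert pick_backend({"pytorch", "cupy"}, use_gpu=True) == "cupy"
```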
@@ -976,12 +1012,6 @@ def main():
          )
          break

-     if pref == "pytorch" and args.interpolation_order == 3:
-         warnings.warn(
-             "Pytorch does not support --interpolation_order 3, setting it to 1."
-         )
-         args.interpolation_order = 1
-
      available_memory = be.get_available_memory() * be.device_count()
      if args.memory is None:
          args.memory = int(args.memory_scaling * available_memory)
@@ -999,17 +1029,49 @@ def main():
          rotations=parse_rotation_logic(args=args, ndim=template.data.ndim),
      )

-     matching_setup, matching_score = MATCHING_EXHAUSTIVE_REGISTER[args.score]
      matching_data.template_filter, matching_data.target_filter = setup_filter(
          args, template, target
      )

-     matching_data.set_matching_dimension(
-         target_dim=target.metadata.get("batch_dimension", None),
-         template_dim=template.metadata.get("batch_dimension", None),
+     target_dims = target.metadata.get("batch_dimension", None)
+     matching_data._set_matching_dimension(target_dims=target_dims, template_dims=None)
+     args.score = "FLC" if target_dims is not None else args.score
+     args.target_batch, args.template_batch = target_dims, None
+
+     template_box = matching_data._output_template_shape
+     if not args.pad_fourier:
+         template_box = tuple(0 for _ in range(len(template_box)))
+
+     target_padding = tuple(0 for _ in range(len(template_box)))
+     if args.pad_target_edges:
+         target_padding = matching_data._output_template_shape
+
+     splits, schedule = compute_parallelization_schedule(
+         shape1=target.shape,
+         shape2=tuple(int(x) for x in template_box),
+         shape1_padding=tuple(int(x) for x in target_padding),
+         max_cores=args.cores,
+         max_ram=args.memory,
+         split_only_outer=args.use_gpu,
+         matching_method=args.score,
+         analyzer_method=callback_class.__name__,
+         backend=be._backend_name,
+         float_nbytes=be.datatype_bytes(be._float_dtype),
+         complex_nbytes=be.datatype_bytes(be._complex_dtype),
+         integer_nbytes=be.datatype_bytes(be._int_dtype),
+         split_axes=target_dims,
      )
-     splits, schedule = compute_schedule(args, matching_data, callback_class)

+     if splits is None:
+         print(
+             "Found no suitable parallelization schedule. Consider increasing"
+             " available RAM or decreasing number of cores."
+         )
+         exit(-1)
+
+     matching_setup, matching_score = MATCHING_EXHAUSTIVE_REGISTER[args.score]
+     if target_dims is not None:
+         matching_score = flc_scoring2
      n_splits = np.prod(list(splits.values()))
      target_split = ", ".join(
          [":".join([str(x) for x in axis]) for axis in splits.items()]
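`compute_parallelization_schedule` appears to return `splits` as a mapping from target axis to the number of chunks along that axis; the closing lines of the hunk above derive the job count and a printable summary from it. A sketch with illustrative values:

```python
import numpy as np

splits = {0: 2, 2: 3}  # split axis 0 twice and axis 2 three times -> 6 jobs
n_splits = np.prod(list(splits.values()))
target_split = ", ".join(":".join(str(x) for x in axis) for axis in splits.items())
assert n_splits == 6 and target_split == "0:2, 2:3"
```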
@@ -1021,7 +1083,8 @@ def main():
          "Center Template": not args.no_centering,
          "Scramble Template": args.scramble_phases,
          "Invert Contrast": args.invert_target_contrast,
-         "Extend Target Edges": args.pad_edges,
+         "Extend Fourier Grid": not args.no_fourier_padding,
+         "Extend Target Edges": not args.no_edge_padding,
          "Interpolation Order": args.interpolation_order,
          "Setup Function": f"{get_func_fullname(matching_setup)}",
          "Scoring Function": f"{get_func_fullname(matching_score)}",
@@ -1034,8 +1097,8 @@ def main():
      )

      compute_options = {
-         "Backend": be._BACKEND_REGISTRY[be._backend_name],
-         "Compute Devices": f"CPU [{args.cores}], GPU [{gpus_used}]",
+         "Backend" :be._BACKEND_REGISTRY[be._backend_name],
+         "Compute Devices" : f"CPU [{args.cores}], GPU [{gpus_used}]",
          "Use Mixed Precision": args.use_mixed_precision,
          "Assigned Memory [MB]": f"{args.memory // 1e6} [out of {available_memory//1e6}]",
          "Temporary Directory": args.temp_directory,
@@ -1057,7 +1120,6 @@ def main():
          "Tilt Angles": args.tilt_angles,
          "Tilt Weighting": args.tilt_weighting,
          "Reconstruction Filter": args.reconstruction_filter,
-         "Extend Filter Grid": args.pad_filter,
      }
      if args.ctf_file is not None or args.defocus is not None:
          filter_args["CTF File"] = args.ctf_file
@@ -1080,7 +1142,8 @@ def main():
      analyzer_args = {
          "score_threshold": args.score_threshold,
          "number_of_peaks": args.number_of_peaks,
-         "min_distance": max(template.shape) // 3,
+         "min_distance" : max(template.shape) // 2,
+         "min_boundary_distance" : max(template.shape) // 2,
          "use_memmap": args.use_memmap,
      }
      print_block(
@@ -1105,13 +1168,14 @@ def main():
          callback_class=callback_class,
          callback_class_args=analyzer_args,
          target_splits=splits,
-         pad_target_edges=args.pad_edges,
-         pad_template_filter=args.pad_filter,
+         pad_target_edges=args.pad_target_edges,
+         pad_fourier=args.pad_fourier,
+         pad_template_filter=not args.no_filter_padding,
          interpolation_order=args.interpolation_order,
      )

      candidates = list(candidates) if candidates is not None else []
-     if issubclass(callback_class, MaxScoreOverRotations):
+     if callback_class == MaxScoreOverRotations:
          if target_mask is not None and args.score != "MCC":
              candidates[0] *= target_mask.data
          with warnings.catch_warnings():
@@ -1123,6 +1187,7 @@ def main():
                  x: np.frombuffer(i, dtype=dtype).reshape(rot_dim, rot_dim)
                  for i, x in candidates[3].items()
              }
+     print(np.where(candidates[0] == candidates[0].max()), candidates[0].max())
      candidates.append((target.origin, template.origin, template.sampling_rate, args))
      write_pickle(data=candidates, filename=args.output)
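The run ends by serializing `candidates` (score data plus a trailing `(target.origin, template.origin, sampling_rate, args)` tuple) with `write_pickle`. A minimal loading sketch, assuming `write_pickle` emits a standard pickle file; the path below is a placeholder for the script's `--output` value:

```python
import pickle

# "output.pickle" stands in for the value passed via --output.
with open("output.pickle", "rb") as handle:
    candidates = pickle.load(handle)

scores = candidates[0]  # score volume for MaxScoreOverRotations runs
target_origin, template_origin, sampling_rate, cli_args = candidates[-1]
print(scores.max(), sampling_rate)
```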