pytme 0.1.8__cp311-cp311-macosx_14_0_arm64.whl → 0.2.0b0__cp311-cp311-macosx_14_0_arm64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (36)
  1. {pytme-0.1.8.data → pytme-0.2.0b0.data}/scripts/match_template.py +148 -126
  2. pytme-0.2.0b0.data/scripts/postprocess.py +570 -0
  3. {pytme-0.1.8.data → pytme-0.2.0b0.data}/scripts/preprocessor_gui.py +244 -60
  4. {pytme-0.1.8.dist-info → pytme-0.2.0b0.dist-info}/METADATA +3 -1
  5. pytme-0.2.0b0.dist-info/RECORD +66 -0
  6. {pytme-0.1.8.dist-info → pytme-0.2.0b0.dist-info}/WHEEL +1 -1
  7. scripts/extract_candidates.py +218 -0
  8. scripts/match_template.py +148 -126
  9. scripts/match_template_filters.py +852 -0
  10. scripts/postprocess.py +380 -435
  11. scripts/preprocessor_gui.py +244 -60
  12. scripts/refine_matches.py +218 -0
  13. tme/__init__.py +2 -1
  14. tme/__version__.py +1 -1
  15. tme/analyzer.py +545 -78
  16. tme/backends/cupy_backend.py +80 -15
  17. tme/backends/npfftw_backend.py +33 -2
  18. tme/backends/pytorch_backend.py +15 -7
  19. tme/density.py +156 -63
  20. tme/extensions.cpython-311-darwin.so +0 -0
  21. tme/matching_constrained.py +195 -0
  22. tme/matching_data.py +76 -32
  23. tme/matching_exhaustive.py +366 -204
  24. tme/matching_memory.py +1 -0
  25. tme/matching_optimization.py +728 -651
  26. tme/matching_utils.py +152 -8
  27. tme/orientations.py +561 -0
  28. tme/preprocessor.py +21 -18
  29. tme/structure.py +2 -37
  30. pytme-0.1.8.data/scripts/postprocess.py +0 -625
  31. pytme-0.1.8.dist-info/RECORD +0 -61
  32. {pytme-0.1.8.data → pytme-0.2.0b0.data}/scripts/estimate_ram_usage.py +0 -0
  33. {pytme-0.1.8.data → pytme-0.2.0b0.data}/scripts/preprocess.py +0 -0
  34. {pytme-0.1.8.dist-info → pytme-0.2.0b0.dist-info}/LICENSE +0 -0
  35. {pytme-0.1.8.dist-info → pytme-0.2.0b0.dist-info}/entry_points.txt +0 -0
  36. {pytme-0.1.8.dist-info → pytme-0.2.0b0.dist-info}/top_level.txt +0 -0
@@ -23,7 +23,7 @@ from qtpy.QtWidgets import QFileDialog
23
23
  from numpy.typing import NDArray
24
24
 
25
25
  from tme import Preprocessor, Density
26
- from tme.matching_utils import create_mask
26
+ from tme.matching_utils import create_mask, load_pickle
27
27
 
28
28
  preprocessor = Preprocessor()
29
29
  SLIDER_MIN, SLIDER_MAX = 0, 25
@@ -416,7 +416,12 @@ class FilterWidget(widgets.Container):
416
416
 
417
417
 
418
418
  def sphere_mask(
419
- template: NDArray, center_x: float, center_y: float, center_z: float, radius: float
419
+ template: NDArray,
420
+ center_x: float,
421
+ center_y: float,
422
+ center_z: float,
423
+ radius: float,
424
+ **kwargs,
420
425
  ) -> NDArray:
421
426
  return create_mask(
422
427
  mask_type="ellipse",
@@ -434,6 +439,7 @@ def ellipsod_mask(
434
439
  radius_x: float,
435
440
  radius_y: float,
436
441
  radius_z: float,
442
+ **kwargs,
437
443
  ) -> NDArray:
438
444
  return create_mask(
439
445
  mask_type="ellipse",
@@ -451,6 +457,7 @@ def box_mask(
451
457
  height_x: int,
452
458
  height_y: int,
453
459
  height_z: int,
460
+ **kwargs,
454
461
  ) -> NDArray:
455
462
  return create_mask(
456
463
  mask_type="box",
@@ -469,6 +476,7 @@ def tube_mask(
469
476
  inner_radius: float,
470
477
  outer_radius: float,
471
478
  height: int,
479
+ **kwargs,
472
480
  ) -> NDArray:
473
481
  return create_mask(
474
482
  mask_type="tube",
@@ -492,6 +500,7 @@ def wedge_mask(
492
500
  omit_negative_frequencies: bool = False,
493
501
  extrude_plane: bool = True,
494
502
  infinite_plane: bool = True,
503
+ **kwargs,
495
504
  ) -> NDArray:
496
505
  if tilt_step <= 0:
497
506
  wedge_mask = preprocessor.continuous_wedge_mask(
@@ -524,7 +533,7 @@ def wedge_mask(
524
533
 
525
534
 
526
535
  def threshold_mask(
527
- template: NDArray, standard_deviation: float = 5.0, invert: bool = False
536
+ template: NDArray, standard_deviation: float = 5.0, invert: bool = False, **kwargs
528
537
  ) -> NDArray:
529
538
  template_mean = template.mean()
530
539
  template_deviation = standard_deviation * template.std()
@@ -537,7 +546,7 @@ def threshold_mask(
537
546
  return mask
538
547
 
539
548
 
540
- def lowpass_mask(template: NDArray, sigma: float = 1.0):
549
+ def lowpass_mask(template: NDArray, sigma: float = 1.0, **kwargs):
541
550
  template = template / template.max()
542
551
  template = (template > np.exp(-2)) * 128.0
543
552
  template = preprocessor.gaussian_filter(template=template, sigma=sigma)
@@ -546,6 +555,20 @@ def lowpass_mask(template: NDArray, sigma: float = 1.0):
546
555
  return mask
547
556
 
548
557
 
558
def shape_mask(template, shapes_layer, expansion_dim, **kwargs):
    """Build a mask by extruding the 2D shapes of a napari Shapes layer.

    Each shape is rasterised onto the plane spanned by all axes of ``template``
    except ``expansion_dim``. It is then extended along ``expansion_dim``. The
    union of all extruded shapes is returned.

    Parameters
    ----------
    template : NDArray
        Array whose shape and dtype the returned mask takes on.
    shapes_layer : napari.layers.Shapes
        Layer providing ``to_masks(mask_shape=...)``, which returns one 2D
        boolean mask per shape.
    expansion_dim : int
        Axis of ``template`` along which the 2D masks are extruded.
    **kwargs
        Ignored. Accepted so the signature matches the other mask builders.

    Returns
    -------
    NDArray
        Array with ``template``'s shape and dtype. It is 1 where any shape
        covers the plane and 0 elsewhere.
    """
    ret = np.zeros_like(template)
    mask_shape = tuple(x for i, x in enumerate(template.shape) if i != expansion_dim)
    masks = shapes_layer.to_masks(mask_shape=mask_shape)
    for mask in masks:
        # The inserted singleton axis broadcasts along ``expansion_dim``, so no
        # explicit np.repeat copy of the full volume is needed per shape.
        np.logical_or(ret, np.expand_dims(mask, axis=expansion_dim), out=ret)

    return ret
570
+
571
+
549
572
  class MaskWidget(widgets.Container):
550
573
  def __init__(self, viewer):
551
574
  super().__init__(layout="vertical")
@@ -564,6 +587,7 @@ class MaskWidget(widgets.Container):
564
587
  "Wedge": wedge_mask,
565
588
  "Threshold": threshold_mask,
566
589
  "Lowpass": lowpass_mask,
590
+ "Shape": shape_mask,
567
591
  }
568
592
 
569
593
  self.method_dropdown = widgets.ComboBox(
@@ -581,16 +605,19 @@ class MaskWidget(widgets.Container):
581
605
  self._update_action_button_state
582
606
  )
583
607
 
584
- self.align_button = widgets.PushButton(text="Align to axis", enabled=False)
585
- self.align_button.changed.connect(self._align_with_axis)
586
608
  self.density_field = widgets.Label()
587
609
  # self.density_field.value = f"Positive Density in Mask: {0:.2f}%"
588
610
 
611
+ self.shapes_layer_dropdown = widgets.ComboBox(
612
+ name="shapes_layer", choices=self._get_shape_layers()
613
+ )
614
+ self.viewer.layers.events.inserted.connect(self._update_shape_layer_choices)
615
+ self.viewer.layers.events.removed.connect(self._update_shape_layer_choices)
616
+
589
617
  self.append(self.method_dropdown)
590
618
  self.append(self.adapt_button)
591
619
  self.append(self.percentile_range_edit)
592
620
 
593
- self.append(self.align_button)
594
621
  self.append(self.action_button)
595
622
  self.append(self.density_field)
596
623
 
@@ -598,43 +625,9 @@ class MaskWidget(widgets.Container):
598
625
  self._on_method_changed(None)
599
626
 
600
627
  def _update_action_button_state(self, event):
601
- self.align_button.enabled = bool(self.viewer.layers.selection.active)
602
628
  self.action_button.enabled = bool(self.viewer.layers.selection.active)
603
629
  self.adapt_button.enabled = bool(self.viewer.layers.selection.active)
604
630
 
605
- def _align_with_axis(self):
606
- active_layer = self.viewer.layers.selection.active
607
-
608
- if active_layer.metadata.get("is_aligned", False):
609
- return
610
-
611
- coords = np.array(np.where(active_layer.data > 0)).T
612
- centered_coords = coords - np.mean(coords, axis=0)
613
- cov_matrix = np.cov(centered_coords, rowvar=False)
614
-
615
- eigenvalues, eigenvectors = np.linalg.eigh(cov_matrix)
616
- principal_eigenvector = eigenvectors[:, np.argmax(eigenvalues)]
617
-
618
- rotation_axis = np.cross(principal_eigenvector, [1, 0, 0])
619
- rotation_angle = np.arccos(np.dot(principal_eigenvector, [1, 0, 0]))
620
- k = rotation_axis / np.linalg.norm(rotation_axis)
621
- K = np.array([[0, -k[2], k[1]], [k[2], 0, -k[0]], [-k[1], k[0], 0]])
622
- rotation_matrix = np.eye(3)
623
- rotation_matrix += np.sin(rotation_angle) * K
624
- rotation_matrix += (1 - np.cos(rotation_angle)) * np.dot(K, K)
625
-
626
- rotated_data = Density.rotate_array(
627
- arr=active_layer.data,
628
- rotation_matrix=rotation_matrix,
629
- use_geometric_center=False,
630
- order=1,
631
- )
632
- eps = np.finfo(rotated_data.dtype).eps
633
- rotated_data[rotated_data < eps] = 0
634
-
635
- active_layer.metadata["is_aligned"] = True
636
- active_layer.data = rotated_data
637
-
638
631
  def _update_initial_values(self, event=None):
639
632
  active_layer = self.viewer.layers.selection.active
640
633
 
@@ -673,17 +666,43 @@ class MaskWidget(widgets.Container):
673
666
  self.action_widgets.clear()
674
667
 
675
668
  function = self.methods.get(self.method_dropdown.value)
676
- widgets = widgets_from_function(function)
677
- for widget in widgets:
669
+ function_widgets = widgets_from_function(function)
670
+ for widget in function_widgets:
678
671
  self.action_widgets.append(widget)
679
672
  self.insert(1, widget)
680
673
 
674
+ for name, param in inspect.signature(function).parameters.items():
675
+ if name == "shapes_layer":
676
+ self.action_widgets.append(self.shapes_layer_dropdown)
677
+ self.insert(1, self.shapes_layer_dropdown)
678
+
679
+ def _get_shape_layers(self):
680
+ layers = [
681
+ layer.name
682
+ for layer in self.viewer.layers
683
+ if isinstance(layer, napari.layers.Shapes)
684
+ ]
685
+ return layers
686
+
687
+ def _update_shape_layer_choices(self, event):
688
+ """Update the choices in the shapes layer dropdown."""
689
+ self.shapes_layer_dropdown.choices = self._get_shape_layers()
690
+
681
691
  def _action(self):
682
692
  function = self.methods.get(self.method_dropdown.value)
683
693
 
684
694
  selected_layer = self.viewer.layers.selection.active
685
695
  kwargs = {widget.name: widget.value for widget in self.action_widgets}
696
+
697
+ if "shapes_layer" in kwargs:
698
+ layer_name = kwargs["shapes_layer"]
699
+ if layer_name not in self.viewer.layers:
700
+ return None
701
+ kwargs["shapes_layer"] = self.viewer.layers[layer_name]
702
+ kwargs["expansion_dim"] = self.viewer.dims.order[0]
703
+
686
704
  processed_data = function(template=selected_layer.data, **kwargs)
705
+
687
706
  new_layer_name = f"{selected_layer.name} ({self.method_dropdown.value})"
688
707
 
689
708
  if new_layer_name in self.viewer.layers:
@@ -705,14 +724,114 @@ class MaskWidget(widgets.Container):
705
724
  metadata["origin_layer"] = selected_layer.name
706
725
  new_layer.metadata = metadata
707
726
 
708
- origin_layer = metadata["origin_layer"]
709
- if origin_layer in self.viewer.layers:
710
- origin_layer = self.viewer.layers[origin_layer]
711
- if np.allclose(origin_layer.data.shape, processed_data.shape):
712
- in_mask = np.sum(np.fmax(origin_layer.data * processed_data, 0))
713
- in_mask /= np.sum(np.fmax(origin_layer.data, 0))
714
- in_mask *= 100
715
- self.density_field.value = f"Positive Density in Mask: {in_mask:.2f}%"
727
+ if self.method_dropdown.value == "Shape":
728
+ new_layer.metadata = {}
729
+
730
+ # origin_layer = metadata["origin_layer"]
731
+ # if origin_layer in self.viewer.layers:
732
+ # origin_layer = self.viewer.layers[origin_layer]
733
+ # if np.allclose(origin_layer.data.shape, processed_data.shape):
734
+ # in_mask = np.sum(np.fmax(origin_layer.data * processed_data, 0))
735
+ # in_mask /= np.sum(np.fmax(origin_layer.data, 0))
736
+ # in_mask *= 100
737
+ # self.density_field.value = f"Positive Density in Mask: {in_mask:.2f}%"
738
+
739
+
740
class AlignmentWidget(widgets.Container):
    """Dock widget with tools to reorient the active layer.

    "Align to axis" rotates the active layer so that the principal axis of its
    positive voxels points along the array axis chosen in the dropdown.
    "Rotate 90" and "Rotate -90" apply ``np.rot90`` to the active layer's data.
    """

    def __init__(self, viewer):
        super().__init__(layout="vertical")

        self.viewer = viewer

        align_button = widgets.PushButton(text="Align to axis", enabled=True)
        self.align_axis = widgets.ComboBox(
            value=None, nullable=True, choices=self._get_active_layer_dims
        )
        # Keep the axis choices in sync with the dimensionality of the selection.
        self.viewer.layers.selection.events.changed.connect(self._update_align_axis)

        align_button.changed.connect(self._align_with_axis)
        container = widgets.Container(
            widgets=[align_button, self.align_axis], layout="horizontal"
        )
        self.append(container)

        rot90 = widgets.PushButton(text="Rotate 90", enabled=True)
        rotneg90 = widgets.PushButton(text="Rotate -90", enabled=True)

        # Discard the signal payload. Connecting ``self._rot90`` directly could
        # let the emitted button value bind positionally to ``swap_axes``.
        rot90.changed.connect(lambda *_: self._rot90(swap_axes=False))
        rotneg90.changed.connect(self._rotneg90)

        container = widgets.Container(widgets=[rot90, rotneg90], layout="horizontal")
        self.append(container)

    def _rot90(self, swap_axes: bool = False):
        """Rotate the active layer's data by 90 degrees (2D display mode only).

        The rotation plane is the selected alignment axis (or, if none is
        selected, ``viewer.dims.order[0]``) together with the lowest other axis.
        ``swap_axes=True`` reverses the plane and so the rotation direction.
        """
        active_layer = self.viewer.layers.selection.active
        if active_layer is None:
            return None
        elif self.viewer.dims.ndisplay != 2:
            return None

        align_axis = self.align_axis.value
        if self.align_axis.value is None:
            align_axis = self.viewer.dims.order[0]

        axes = [
            align_axis,
            *[i for i in range(len(self.viewer.dims.order)) if i != align_axis],
        ][:2]
        axes = axes[::-1] if swap_axes else axes
        active_layer.data = np.rot90(active_layer.data, k=1, axes=axes)

    def _rotneg90(self):
        """Rotate the active layer's data by -90 degrees."""
        return self._rot90(swap_axes=True)

    def _get_active_layer_dims(self, *args):
        """Return the axis indices of the active layer, or ``()`` if none is selected."""
        active_layer = self.viewer.layers.selection.active
        if active_layer is None:
            return ()
        return [i for i in range(active_layer.data.ndim)]

    def _update_align_axis(self, *args):
        """Refresh the alignment-axis dropdown after the layer selection changes."""
        self.align_axis.choices = self._get_active_layer_dims()

    @staticmethod
    def _rotation_onto_axis(vector, axis):
        """Return the 3x3 rotation (Rodrigues' formula) taking unit ``vector`` onto unit ``axis``.

        Parallel inputs give the identity. Antiparallel inputs give a rotation
        by pi about an axis perpendicular to ``vector``. In both cases the cross
        product vanishes and the generic formula would divide by zero.
        """
        # Clip so rounding in the dot product cannot push arccos out of its domain.
        cos_angle = float(np.clip(np.dot(vector, axis), -1.0, 1.0))
        rotation_axis = np.cross(vector, axis)
        axis_norm = np.linalg.norm(rotation_axis)

        if axis_norm < 1e-12:
            if cos_angle > 0:
                return np.eye(3)
            # Pick the basis vector least aligned with ``vector`` to obtain a
            # well-conditioned perpendicular rotation axis.
            helper = np.eye(3)[np.argmin(np.abs(vector))]
            rotation_axis = np.cross(vector, helper)
            axis_norm = np.linalg.norm(rotation_axis)

        k = rotation_axis / axis_norm
        K = np.array([[0, -k[2], k[1]], [k[2], 0, -k[0]], [-k[1], k[0], 0]])
        rotation_angle = np.arccos(cos_angle)

        rotation_matrix = np.eye(3)
        rotation_matrix += np.sin(rotation_angle) * K
        rotation_matrix += (1 - np.cos(rotation_angle)) * np.dot(K, K)
        return rotation_matrix

    def _align_with_axis(self):
        """Rotate the active 3D layer so its principal axis lies along the chosen axis.

        The principal axis is the eigenvector with the largest eigenvalue of the
        covariance of the coordinates of voxels with value > 0. The layer's
        ``metadata["is_aligned"]`` records the axis used, so aligning twice to
        the same axis is a no-op. The method returns without changes when:

        - no layer or axis is selected,
        - the layer is not 3D, or
        - fewer than two voxels are positive.
        """
        active_layer = self.viewer.layers.selection.active
        if active_layer is None or self.align_axis.value is None:
            return None

        target_axis = int(self.align_axis.value)
        if active_layer.metadata.get("is_aligned", None) == target_axis:
            return None

        data = active_layer.data
        # The axis-angle construction below only exists for 3D volumes.
        if data.ndim != 3:
            return None

        coords = np.array(np.where(data > 0)).T
        # A covariance matrix (and hence a principal axis) needs >= 2 points.
        if coords.shape[0] < 2:
            return None

        alignment_axis = np.zeros(data.ndim)
        alignment_axis[target_axis] = 1

        centered_coords = coords - np.mean(coords, axis=0)
        cov_matrix = np.cov(centered_coords, rowvar=False)

        eigenvalues, eigenvectors = np.linalg.eigh(cov_matrix)
        principal_eigenvector = eigenvectors[:, np.argmax(eigenvalues)]

        rotation_matrix = self._rotation_onto_axis(
            principal_eigenvector, alignment_axis
        )

        rotated_data = Density.rotate_array(
            arr=data,
            rotation_matrix=rotation_matrix,
            use_geometric_center=False,
            order=1,
        )
        # Zero out interpolation residue below machine epsilon.
        eps = np.finfo(rotated_data.dtype).eps
        rotated_data[rotated_data < eps] = 0

        active_layer.metadata["is_aligned"] = target_axis
        active_layer.data = rotated_data
716
835
 
717
836
 
718
837
  class ExportWidget(widgets.Container):
@@ -803,16 +922,36 @@ class PointCloudWidget(widgets.Container):
803
922
  options=options,
804
923
  )
805
924
 
806
- if filename:
807
- layer = self.viewer.layers.selection.active
808
- if layer and isinstance(layer, napari.layers.Points):
809
- original_dataframe = self.dataframes.get(layer.name, pd.DataFrame())
925
+ if not filename:
926
+ return None
810
927
 
811
- export_data = pd.DataFrame(layer.data, columns=["z", "y", "x"])
812
- merged_data = pd.merge(
813
- export_data, original_dataframe, on=["z", "y", "x"], how="left"
814
- )
815
- merged_data.to_csv(filename, sep="\t", index=False)
928
+ layer = self.viewer.layers.selection.active
929
+ if layer and isinstance(layer, napari.layers.Points):
930
+ original_dataframe = self.dataframes.get(
931
+ layer.name, pd.DataFrame(columns=["z", "y", "x"])
932
+ )
933
+
934
+ export_data = pd.DataFrame(layer.data, columns=["z", "y", "x"])
935
+ merged_data = pd.merge(
936
+ export_data, original_dataframe, on=["z", "y", "x"], how="left"
937
+ )
938
+
939
+ merged_data["z"] = merged_data["z"].astype(int)
940
+ merged_data["y"] = merged_data["y"].astype(int)
941
+ merged_data["x"] = merged_data["x"].astype(int)
942
+
943
+ euler_columns = ["euler_z", "euler_y", "euler_x"]
944
+ for col in euler_columns:
945
+ if col not in merged_data.columns:
946
+ continue
947
+ merged_data[col] = merged_data[col].fillna(0)
948
+
949
+ if "score" in merged_data.columns:
950
+ merged_data["score"] = merged_data["score"].fillna(1)
951
+ if "detail" in merged_data.columns:
952
+ merged_data["detail"] = merged_data["detail"].fillna(2)
953
+
954
+ merged_data.to_csv(filename, sep="\t", index=False)
816
955
 
817
956
  def _get_load_path(self, event):
818
957
  options = QFileDialog.Options()
@@ -830,6 +969,13 @@ class PointCloudWidget(widgets.Container):
830
969
  dataframe = pd.read_csv(filename, sep="\t")
831
970
  points = dataframe[["z", "y", "x"]].values
832
971
  layer_name = filename.split("/")[-1]
972
+
973
+ if "score" not in dataframe.columns:
974
+ dataframe["score"] = 1
975
+
976
+ if "detail" not in dataframe.columns:
977
+ dataframe["detail"] = -2
978
+
833
979
  point_properties = {
834
980
  "score": np.array(dataframe["score"].values),
835
981
  "detail": np.array(dataframe["detail"].values),
@@ -849,6 +995,38 @@ class PointCloudWidget(widgets.Container):
849
995
  self.dataframes[layer_name] = dataframe
850
996
 
851
997
 
998
class MatchingWidget(widgets.Container):
    """Dock widget that imports a template matching result pickle.

    After a file is chosen, element 2 of the loaded object is added as the
    "Rotations" image layer and element 0 as the "Scores" image layer.
    """

    def __init__(self, viewer):
        super().__init__(layout="vertical")

        self.viewer = viewer
        self.dataframes = {}

        button = widgets.PushButton(name="Import", text="Import Pickle")
        self.import_button = button
        button.clicked.connect(self._get_load_path)

        self.append(button)

    def _get_load_path(self, event):
        """Prompt for a pickle file and load it; cancelling the dialog does nothing."""
        filename, _ = QFileDialog.getOpenFileName(
            self.native,
            "Open Pickle File...",
            "",
            "Pickle Files (*.pickle);;All Files (*)",
            options=QFileDialog.Options(),
        )
        if not filename:
            return None
        self._load_data(filename)

    def _load_data(self, filename):
        """Load ``filename`` via ``load_pickle`` and show two of its entries as images."""
        # NOTE(review): load_pickle presumably unpickles the file, which can run
        # arbitrary code - only open result files from trusted sources.
        data = load_pickle(filename)

        # (layer name, index into the loaded result, colormap), added in this order.
        for layer_name, index, colormap in (
            ("Rotations", 2, "orange"),
            ("Scores", 0, "turbo"),
        ):
            self.viewer.add_image(data=data[index], name=layer_name, colormap=colormap)
+
1029
+
852
1030
  def main():
853
1031
  viewer = napari.Viewer()
854
1032
 
@@ -856,10 +1034,16 @@ def main():
856
1034
  mask_widget = MaskWidget(viewer)
857
1035
  export_widget = ExportWidget(viewer)
858
1036
  point_cloud = PointCloudWidget(viewer)
1037
+ matching_widget = MatchingWidget(viewer)
1038
+ alignment_widget = AlignmentWidget(viewer)
859
1039
 
860
1040
  viewer.window.add_dock_widget(widget=filter_widget, name="Preprocess", area="right")
1041
+ viewer.window.add_dock_widget(
1042
+ widget=alignment_widget, name="Alignment", area="right"
1043
+ )
861
1044
  viewer.window.add_dock_widget(widget=mask_widget, name="Mask", area="right")
862
1045
  viewer.window.add_dock_widget(widget=point_cloud, name="PointCloud", area="left")
1046
+ viewer.window.add_dock_widget(widget=matching_widget, name="Matching", area="left")
863
1047
 
864
1048
  viewer.window.add_dock_widget(widget=export_widget, name="Export", area="right")
865
1049
 
@@ -0,0 +1,218 @@
1
+ #!python3
2
+ """ CLI to refine template matching candidates.
3
+
4
+ Copyright (c) 2023-2024 European Molecular Biology Laboratory
5
+
6
+ Author: Valentin Maurer <valentin.maurer@embl-hamburg.de>
7
+ """
8
+ import argparse
9
+ from time import time
10
+
11
+ import numpy as np
12
+ from numpy.typing import NDArray
13
+
14
+ from tme.backends import backend
15
+ from tme import Density, Structure
16
+ from tme.matching_data import MatchingData
17
+ from tme.analyzer import MaxScoreOverRotations, MaxScoreOverTranslations
18
+ from tme.matching_utils import (
19
+ load_pickle,
20
+ get_rotation_matrices,
21
+ compute_parallelization_schedule,
22
+ )
23
+ from tme.matching_exhaustive import scan_subsets, MATCHING_EXHAUSTIVE_REGISTER
24
+
25
+ from postprocess import Orientations
26
+ from match_template import load_and_validate_mask
27
+
28
def parse_args():
    """Parse the command line for the refinement CLI.

    Returns
    -------
    argparse.Namespace
        Has ``input_file``, ``orientations`` and ``output_file`` (all required
        strings) and ``cores`` (int, set with ``-n``, default 4).
    """
    parser = argparse.ArgumentParser(
        description="Refine Template Matching Orientations."
    )

    required_options = (
        (
            "--input_file",
            {"type": str, "help": "Path to the output of match_template.py."},
        ),
        (
            "--orientations",
            {"type": str, "help": "Path to orientations from postprocess.py."},
        ),
        ("--output_file", {"help": "Path to the refined output orientations."}),
    )
    for flag, options in required_options:
        parser.add_argument(flag, required=True, **options)

    parser.add_argument(
        "-n",
        dest="cores",
        required=False,
        type=int,
        default=4,
        help="Number of cores used for template matching.",
    )
    return parser.parse_args()
61
+
62
+
63
def load_template(filepath: str, sampling_rate: NDArray) -> "Density":
    """Load the template at ``filepath`` as a ``Density``.

    The file is first read as a density map. If ``Density.from_file`` raises
    ``ValueError``, it is read with ``Structure.from_file`` instead and
    converted into a density at ``sampling_rate``.
    """
    try:
        return Density.from_file(filepath)
    except ValueError:
        pass

    structure = Structure.from_file(filepath)
    return Density.from_structure(structure, sampling_rate=sampling_rate)
71
+
72
+
73
def _observation_slices(peaks, template_shape, target_shape):
    """Return, per peak, the slices of a ``template_shape`` box centred on it.

    Peaks whose box would extend past the bounds of ``target_shape`` are
    dropped. Odd template extents are rounded so that every box has exactly
    ``template_shape``.
    """
    template_shape = np.asarray(template_shape, dtype=int)
    half_shape = np.divide(template_shape, 2).astype(int)
    starts = np.subtract(peaks, half_shape)
    stops = np.add(peaks, half_shape) + np.mod(template_shape, 2)

    inside = np.all((starts >= 0) & (stops <= np.asarray(target_shape)), axis=1)
    return [
        tuple(slice(start, stop) for start, stop in zip(start_row, stop_row))
        for start_row, stop_row in zip(starts[inside], stops[inside])
    ]


def main():
    """Re-score candidate orientations from a previous match_template.py run.

    Steps:

    1. Load the last element of the match_template.py pickle, which is unpacked
       into target origin, sampling rate and the original CLI arguments.
    2. Rebuild the template and its mask the same way.
    3. Extract a template-sized box of the target around every orientation that
       lies fully inside the target.
    4. Score each box with FLC over ``get_rotation_matrices(15)``.

    Prints the maximum of the first result entry and the runtime, and writes
    ``candidates[0][0]`` to ``scores.mrc``.
    """
    args = parse_args()
    # NOTE(review): load_pickle presumably unpickles; only use trusted inputs.
    meta = load_pickle(args.input_file)[-1]
    _target_origin, _, sampling_rate, cli_args = meta

    orientations = Orientations.from_file(
        filename=args.orientations, file_format="text"
    )

    template = load_template(cli_args.template, sampling_rate)
    template_mask = load_and_validate_mask(
        mask_target=template, mask_path=cli_args.template_mask
    )

    # Start from a zero shift: with --no_centering, ``centered`` is never
    # called, and the mask alignment below still needs ``translation``.
    translation = np.zeros(template.data.ndim)
    if not cli_args.no_centering:
        template, translation = template.centered(0)

    if template_mask is None:
        template_mask = template.empty
        if not cli_args.no_centering:
            enclosing_box = template.minimum_enclosing_box(
                0, use_geometric_center=False
            )
            template_mask.adjust_box(enclosing_box)

        template_mask.data[:] = 1
        translation = np.zeros_like(translation)

    template_mask.pad(template.shape, center=False)
    origin_translation = np.divide(
        np.subtract(template.origin, template_mask.origin), template.sampling_rate
    )
    translation = np.add(translation, origin_translation)

    template_mask = template_mask.rigid_transform(
        rotation_matrix=np.eye(template_mask.data.ndim),
        translation=-translation,
        order=1,
    )
    template_mask.origin = template.origin.copy()

    target = Density.from_file(cli_args.target)
    peaks = orientations.translations.astype(int)
    observation_slices = _observation_slices(peaks, template.shape, target.shape)
    if not observation_slices:
        print("No orientations lie fully inside the target; nothing to refine.")
        return None

    ndim = template.data.ndim
    matching_data = MatchingData(target=target, template=template)
    matching_data.rotations = np.eye(ndim).reshape(1, ndim, ndim)

    target_pad = matching_data.target_padding(pad_target=True)
    out_shape = np.add(target_pad, template.shape).astype(int)

    observations = np.zeros((len(observation_slices), *out_shape))
    for idx, obs_slice in enumerate(observation_slices):
        subset = matching_data.subset_by_slice(
            target_slice=obs_slice,
            target_pad=target_pad,
            invert_target=cli_args.invert_target_contrast,
        )
        # Score the extracted target region itself, not a padded template copy.
        # Assumes subset.target is a NumPy array of shape out_shape - confirm
        # for non-NumPy backends.
        observations[idx] = subset.target

    matching_data = MatchingData(target=observations, template=template)
    # Axis 0 of ``observations`` indexes independent boxes (batch dimension).
    matching_data._set_batch_dimension(target_dims=0, template_dims=None)
    matching_data.rotations = get_rotation_matrices(15)
    if template_mask is not None:
        matching_data.template_mask = template_mask.data

    template_box = np.ones(len(matching_data._output_template_shape), dtype=int)
    target_padding = np.zeros_like(matching_data._output_target_shape)

    scoring_method = "FLC"
    callback_class = MaxScoreOverRotations
    splits, schedule = compute_parallelization_schedule(
        shape1=matching_data._output_target_shape,
        shape2=template_box,
        shape1_padding=target_padding,
        max_cores=args.cores,
        max_ram=backend.get_available_memory(),
        split_only_outer=False,
        matching_method=scoring_method,
        analyzer_method=callback_class.__name__,
        backend=backend._backend_name,
        float_nbytes=backend.datatype_bytes(backend._default_dtype),
        complex_nbytes=backend.datatype_bytes(backend._complex_dtype),
        integer_nbytes=backend.datatype_bytes(backend._default_dtype_int),
    )

    matching_setup, matching_score = MATCHING_EXHAUSTIVE_REGISTER[scoring_method]

    start = time()
    candidates = scan_subsets(
        matching_data=matching_data,
        matching_score=matching_score,
        matching_setup=matching_setup,
        callback_class=callback_class,
        callback_class_args={},
        target_splits=splits,
        job_schedule=schedule,
        pad_target_edges=False,
        pad_fourier=False,
        interpolation_order=cli_args.interpolation_order,
    )
    print(candidates[0].max())
    # NOTE(review): --output_file is parsed but unused; the score map is
    # written to a fixed path until orientation export is implemented.
    Density(candidates[0][0]).to_file("scores.mrc")

    runtime = time() - start
    print(f"\nRuntime real: {runtime:.3f}s user: {(runtime * args.cores):.3f}s.")
215
+
216
+
217
if __name__ == "__main__":
    # Run the refinement CLI only when this file is executed as a script.
    main()
tme/__init__.py CHANGED
@@ -1,6 +1,7 @@
1
+ from . import extensions
1
2
  from .__version__ import __version__
2
3
  from .density import Density
3
4
  from .preprocessor import Preprocessor
4
5
  from .structure import Structure
6
+ from .orientations import Orientations
5
7
  from .matching_optimization import FitRefinement
6
- from . import extensions
tme/__version__.py CHANGED
@@ -1 +1 @@
1
- __version__ = "0.1.8"
1
+ __version__ = "0.2.0b"