pytme 0.1.9__cp311-cp311-macosx_14_0_arm64.whl → 0.2.0__cp311-cp311-macosx_14_0_arm64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- pytme-0.2.0.data/scripts/match_template.py +1019 -0
- pytme-0.2.0.data/scripts/postprocess.py +570 -0
- {pytme-0.1.9.data → pytme-0.2.0.data}/scripts/preprocessor_gui.py +244 -60
- {pytme-0.1.9.dist-info → pytme-0.2.0.dist-info}/METADATA +3 -1
- pytme-0.2.0.dist-info/RECORD +72 -0
- {pytme-0.1.9.dist-info → pytme-0.2.0.dist-info}/WHEEL +1 -1
- scripts/extract_candidates.py +218 -0
- scripts/match_template.py +459 -218
- pytme-0.1.9.data/scripts/match_template.py → scripts/match_template_filters.py +459 -218
- scripts/postprocess.py +380 -435
- scripts/preprocessor_gui.py +244 -60
- scripts/refine_matches.py +218 -0
- tme/__init__.py +2 -1
- tme/__version__.py +1 -1
- tme/analyzer.py +533 -78
- tme/backends/cupy_backend.py +80 -15
- tme/backends/npfftw_backend.py +35 -6
- tme/backends/pytorch_backend.py +15 -7
- tme/density.py +173 -78
- tme/extensions.cpython-311-darwin.so +0 -0
- tme/matching_constrained.py +195 -0
- tme/matching_data.py +76 -33
- tme/matching_exhaustive.py +354 -225
- tme/matching_memory.py +1 -0
- tme/matching_optimization.py +753 -649
- tme/matching_utils.py +152 -8
- tme/orientations.py +561 -0
- tme/preprocessing/__init__.py +2 -0
- tme/preprocessing/_utils.py +176 -0
- tme/preprocessing/composable_filter.py +30 -0
- tme/preprocessing/compose.py +52 -0
- tme/preprocessing/frequency_filters.py +322 -0
- tme/preprocessing/tilt_series.py +967 -0
- tme/preprocessor.py +35 -25
- tme/structure.py +2 -37
- pytme-0.1.9.data/scripts/postprocess.py +0 -625
- pytme-0.1.9.dist-info/RECORD +0 -61
- {pytme-0.1.9.data → pytme-0.2.0.data}/scripts/estimate_ram_usage.py +0 -0
- {pytme-0.1.9.data → pytme-0.2.0.data}/scripts/preprocess.py +0 -0
- {pytme-0.1.9.dist-info → pytme-0.2.0.dist-info}/LICENSE +0 -0
- {pytme-0.1.9.dist-info → pytme-0.2.0.dist-info}/entry_points.txt +0 -0
- {pytme-0.1.9.dist-info → pytme-0.2.0.dist-info}/top_level.txt +0 -0
scripts/preprocessor_gui.py
CHANGED
@@ -23,7 +23,7 @@ from qtpy.QtWidgets import QFileDialog
|
|
23
23
|
from numpy.typing import NDArray
|
24
24
|
|
25
25
|
from tme import Preprocessor, Density
|
26
|
-
from tme.matching_utils import create_mask
|
26
|
+
from tme.matching_utils import create_mask, load_pickle
|
27
27
|
|
28
28
|
preprocessor = Preprocessor()
|
29
29
|
SLIDER_MIN, SLIDER_MAX = 0, 25
|
@@ -416,7 +416,12 @@ class FilterWidget(widgets.Container):
|
|
416
416
|
|
417
417
|
|
418
418
|
def sphere_mask(
|
419
|
-
template: NDArray,
|
419
|
+
template: NDArray,
|
420
|
+
center_x: float,
|
421
|
+
center_y: float,
|
422
|
+
center_z: float,
|
423
|
+
radius: float,
|
424
|
+
**kwargs,
|
420
425
|
) -> NDArray:
|
421
426
|
return create_mask(
|
422
427
|
mask_type="ellipse",
|
@@ -434,6 +439,7 @@ def ellipsod_mask(
|
|
434
439
|
radius_x: float,
|
435
440
|
radius_y: float,
|
436
441
|
radius_z: float,
|
442
|
+
**kwargs,
|
437
443
|
) -> NDArray:
|
438
444
|
return create_mask(
|
439
445
|
mask_type="ellipse",
|
@@ -451,6 +457,7 @@ def box_mask(
|
|
451
457
|
height_x: int,
|
452
458
|
height_y: int,
|
453
459
|
height_z: int,
|
460
|
+
**kwargs,
|
454
461
|
) -> NDArray:
|
455
462
|
return create_mask(
|
456
463
|
mask_type="box",
|
@@ -469,6 +476,7 @@ def tube_mask(
|
|
469
476
|
inner_radius: float,
|
470
477
|
outer_radius: float,
|
471
478
|
height: int,
|
479
|
+
**kwargs,
|
472
480
|
) -> NDArray:
|
473
481
|
return create_mask(
|
474
482
|
mask_type="tube",
|
@@ -492,6 +500,7 @@ def wedge_mask(
|
|
492
500
|
omit_negative_frequencies: bool = False,
|
493
501
|
extrude_plane: bool = True,
|
494
502
|
infinite_plane: bool = True,
|
503
|
+
**kwargs,
|
495
504
|
) -> NDArray:
|
496
505
|
if tilt_step <= 0:
|
497
506
|
wedge_mask = preprocessor.continuous_wedge_mask(
|
@@ -524,7 +533,7 @@ def wedge_mask(
|
|
524
533
|
|
525
534
|
|
526
535
|
def threshold_mask(
|
527
|
-
template: NDArray, standard_deviation: float = 5.0, invert: bool = False
|
536
|
+
template: NDArray, standard_deviation: float = 5.0, invert: bool = False, **kwargs
|
528
537
|
) -> NDArray:
|
529
538
|
template_mean = template.mean()
|
530
539
|
template_deviation = standard_deviation * template.std()
|
@@ -537,7 +546,7 @@ def threshold_mask(
|
|
537
546
|
return mask
|
538
547
|
|
539
548
|
|
540
|
-
def lowpass_mask(template: NDArray, sigma: float = 1.0):
|
549
|
+
def lowpass_mask(template: NDArray, sigma: float = 1.0, **kwargs):
|
541
550
|
template = template / template.max()
|
542
551
|
template = (template > np.exp(-2)) * 128.0
|
543
552
|
template = preprocessor.gaussian_filter(template=template, sigma=sigma)
|
@@ -546,6 +555,20 @@ def lowpass_mask(template: NDArray, sigma: float = 1.0):
|
|
546
555
|
return mask
|
547
556
|
|
548
557
|
|
558
|
+
def shape_mask(template, shapes_layer, expansion_dim):
|
559
|
+
ret = np.zeros_like(template)
|
560
|
+
mask_shape = tuple(x for i, x in enumerate(template.shape) if i != expansion_dim)
|
561
|
+
masks = shapes_layer.to_masks(mask_shape=mask_shape)
|
562
|
+
for index, shape_type in enumerate(shapes_layer.shape_type):
|
563
|
+
mask = np.expand_dims(masks[index], axis=expansion_dim)
|
564
|
+
mask = np.repeat(
|
565
|
+
mask, repeats=template.shape[expansion_dim], axis=expansion_dim
|
566
|
+
)
|
567
|
+
np.logical_or(ret, mask, out=ret)
|
568
|
+
|
569
|
+
return ret
|
570
|
+
|
571
|
+
|
549
572
|
class MaskWidget(widgets.Container):
|
550
573
|
def __init__(self, viewer):
|
551
574
|
super().__init__(layout="vertical")
|
@@ -564,6 +587,7 @@ class MaskWidget(widgets.Container):
|
|
564
587
|
"Wedge": wedge_mask,
|
565
588
|
"Threshold": threshold_mask,
|
566
589
|
"Lowpass": lowpass_mask,
|
590
|
+
"Shape": shape_mask,
|
567
591
|
}
|
568
592
|
|
569
593
|
self.method_dropdown = widgets.ComboBox(
|
@@ -581,16 +605,19 @@ class MaskWidget(widgets.Container):
|
|
581
605
|
self._update_action_button_state
|
582
606
|
)
|
583
607
|
|
584
|
-
self.align_button = widgets.PushButton(text="Align to axis", enabled=False)
|
585
|
-
self.align_button.changed.connect(self._align_with_axis)
|
586
608
|
self.density_field = widgets.Label()
|
587
609
|
# self.density_field.value = f"Positive Density in Mask: {0:.2f}%"
|
588
610
|
|
611
|
+
self.shapes_layer_dropdown = widgets.ComboBox(
|
612
|
+
name="shapes_layer", choices=self._get_shape_layers()
|
613
|
+
)
|
614
|
+
self.viewer.layers.events.inserted.connect(self._update_shape_layer_choices)
|
615
|
+
self.viewer.layers.events.removed.connect(self._update_shape_layer_choices)
|
616
|
+
|
589
617
|
self.append(self.method_dropdown)
|
590
618
|
self.append(self.adapt_button)
|
591
619
|
self.append(self.percentile_range_edit)
|
592
620
|
|
593
|
-
self.append(self.align_button)
|
594
621
|
self.append(self.action_button)
|
595
622
|
self.append(self.density_field)
|
596
623
|
|
@@ -598,43 +625,9 @@ class MaskWidget(widgets.Container):
|
|
598
625
|
self._on_method_changed(None)
|
599
626
|
|
600
627
|
def _update_action_button_state(self, event):
|
601
|
-
self.align_button.enabled = bool(self.viewer.layers.selection.active)
|
602
628
|
self.action_button.enabled = bool(self.viewer.layers.selection.active)
|
603
629
|
self.adapt_button.enabled = bool(self.viewer.layers.selection.active)
|
604
630
|
|
605
|
-
def _align_with_axis(self):
|
606
|
-
active_layer = self.viewer.layers.selection.active
|
607
|
-
|
608
|
-
if active_layer.metadata.get("is_aligned", False):
|
609
|
-
return
|
610
|
-
|
611
|
-
coords = np.array(np.where(active_layer.data > 0)).T
|
612
|
-
centered_coords = coords - np.mean(coords, axis=0)
|
613
|
-
cov_matrix = np.cov(centered_coords, rowvar=False)
|
614
|
-
|
615
|
-
eigenvalues, eigenvectors = np.linalg.eigh(cov_matrix)
|
616
|
-
principal_eigenvector = eigenvectors[:, np.argmax(eigenvalues)]
|
617
|
-
|
618
|
-
rotation_axis = np.cross(principal_eigenvector, [1, 0, 0])
|
619
|
-
rotation_angle = np.arccos(np.dot(principal_eigenvector, [1, 0, 0]))
|
620
|
-
k = rotation_axis / np.linalg.norm(rotation_axis)
|
621
|
-
K = np.array([[0, -k[2], k[1]], [k[2], 0, -k[0]], [-k[1], k[0], 0]])
|
622
|
-
rotation_matrix = np.eye(3)
|
623
|
-
rotation_matrix += np.sin(rotation_angle) * K
|
624
|
-
rotation_matrix += (1 - np.cos(rotation_angle)) * np.dot(K, K)
|
625
|
-
|
626
|
-
rotated_data = Density.rotate_array(
|
627
|
-
arr=active_layer.data,
|
628
|
-
rotation_matrix=rotation_matrix,
|
629
|
-
use_geometric_center=False,
|
630
|
-
order=1,
|
631
|
-
)
|
632
|
-
eps = np.finfo(rotated_data.dtype).eps
|
633
|
-
rotated_data[rotated_data < eps] = 0
|
634
|
-
|
635
|
-
active_layer.metadata["is_aligned"] = True
|
636
|
-
active_layer.data = rotated_data
|
637
|
-
|
638
631
|
def _update_initial_values(self, event=None):
|
639
632
|
active_layer = self.viewer.layers.selection.active
|
640
633
|
|
@@ -673,17 +666,43 @@ class MaskWidget(widgets.Container):
|
|
673
666
|
self.action_widgets.clear()
|
674
667
|
|
675
668
|
function = self.methods.get(self.method_dropdown.value)
|
676
|
-
|
677
|
-
for widget in
|
669
|
+
function_widgets = widgets_from_function(function)
|
670
|
+
for widget in function_widgets:
|
678
671
|
self.action_widgets.append(widget)
|
679
672
|
self.insert(1, widget)
|
680
673
|
|
674
|
+
for name, param in inspect.signature(function).parameters.items():
|
675
|
+
if name == "shapes_layer":
|
676
|
+
self.action_widgets.append(self.shapes_layer_dropdown)
|
677
|
+
self.insert(1, self.shapes_layer_dropdown)
|
678
|
+
|
679
|
+
def _get_shape_layers(self):
|
680
|
+
layers = [
|
681
|
+
layer.name
|
682
|
+
for layer in self.viewer.layers
|
683
|
+
if isinstance(layer, napari.layers.Shapes)
|
684
|
+
]
|
685
|
+
return layers
|
686
|
+
|
687
|
+
def _update_shape_layer_choices(self, event):
|
688
|
+
"""Update the choices in the shapes layer dropdown."""
|
689
|
+
self.shapes_layer_dropdown.choices = self._get_shape_layers()
|
690
|
+
|
681
691
|
def _action(self):
|
682
692
|
function = self.methods.get(self.method_dropdown.value)
|
683
693
|
|
684
694
|
selected_layer = self.viewer.layers.selection.active
|
685
695
|
kwargs = {widget.name: widget.value for widget in self.action_widgets}
|
696
|
+
|
697
|
+
if "shapes_layer" in kwargs:
|
698
|
+
layer_name = kwargs["shapes_layer"]
|
699
|
+
if layer_name not in self.viewer.layers:
|
700
|
+
return None
|
701
|
+
kwargs["shapes_layer"] = self.viewer.layers[layer_name]
|
702
|
+
kwargs["expansion_dim"] = self.viewer.dims.order[0]
|
703
|
+
|
686
704
|
processed_data = function(template=selected_layer.data, **kwargs)
|
705
|
+
|
687
706
|
new_layer_name = f"{selected_layer.name} ({self.method_dropdown.value})"
|
688
707
|
|
689
708
|
if new_layer_name in self.viewer.layers:
|
@@ -705,14 +724,114 @@ class MaskWidget(widgets.Container):
|
|
705
724
|
metadata["origin_layer"] = selected_layer.name
|
706
725
|
new_layer.metadata = metadata
|
707
726
|
|
708
|
-
|
709
|
-
|
710
|
-
|
711
|
-
|
712
|
-
|
713
|
-
|
714
|
-
|
715
|
-
|
727
|
+
if self.method_dropdown.value == "Shape":
|
728
|
+
new_layer.metadata = {}
|
729
|
+
|
730
|
+
# origin_layer = metadata["origin_layer"]
|
731
|
+
# if origin_layer in self.viewer.layers:
|
732
|
+
# origin_layer = self.viewer.layers[origin_layer]
|
733
|
+
# if np.allclose(origin_layer.data.shape, processed_data.shape):
|
734
|
+
# in_mask = np.sum(np.fmax(origin_layer.data * processed_data, 0))
|
735
|
+
# in_mask /= np.sum(np.fmax(origin_layer.data, 0))
|
736
|
+
# in_mask *= 100
|
737
|
+
# self.density_field.value = f"Positive Density in Mask: {in_mask:.2f}%"
|
738
|
+
|
739
|
+
|
740
|
+
class AlignmentWidget(widgets.Container):
|
741
|
+
def __init__(self, viewer):
|
742
|
+
super().__init__(layout="vertical")
|
743
|
+
|
744
|
+
self.viewer = viewer
|
745
|
+
|
746
|
+
align_button = widgets.PushButton(text="Align to axis", enabled=True)
|
747
|
+
self.align_axis = widgets.ComboBox(
|
748
|
+
value=None, nullable=True, choices=self._get_active_layer_dims
|
749
|
+
)
|
750
|
+
self.viewer.layers.selection.events.changed.connect(self._update_align_axis)
|
751
|
+
|
752
|
+
align_button.changed.connect(self._align_with_axis)
|
753
|
+
container = widgets.Container(
|
754
|
+
widgets=[align_button, self.align_axis], layout="horizontal"
|
755
|
+
)
|
756
|
+
self.append(container)
|
757
|
+
|
758
|
+
rot90 = widgets.PushButton(text="Rotate 90", enabled=True)
|
759
|
+
rotneg90 = widgets.PushButton(text="Rotate -90", enabled=True)
|
760
|
+
|
761
|
+
rot90.changed.connect(self._rot90)
|
762
|
+
rotneg90.changed.connect(self._rotneg90)
|
763
|
+
|
764
|
+
container = widgets.Container(widgets=[rot90, rotneg90], layout="horizontal")
|
765
|
+
self.append(container)
|
766
|
+
|
767
|
+
def _rot90(self, swap_axes: bool = False):
|
768
|
+
active_layer = self.viewer.layers.selection.active
|
769
|
+
if active_layer is None:
|
770
|
+
return None
|
771
|
+
elif self.viewer.dims.ndisplay != 2:
|
772
|
+
return None
|
773
|
+
|
774
|
+
align_axis = self.align_axis.value
|
775
|
+
if self.align_axis.value is None:
|
776
|
+
align_axis = self.viewer.dims.order[0]
|
777
|
+
|
778
|
+
axes = [
|
779
|
+
align_axis,
|
780
|
+
*[i for i in range(len(self.viewer.dims.order)) if i != align_axis],
|
781
|
+
][:2]
|
782
|
+
axes = axes[::-1] if swap_axes else axes
|
783
|
+
active_layer.data = np.rot90(active_layer.data, k=1, axes=axes)
|
784
|
+
|
785
|
+
def _rotneg90(self):
|
786
|
+
return self._rot90(swap_axes=True)
|
787
|
+
|
788
|
+
def _get_active_layer_dims(self, *args):
|
789
|
+
active_layer = self.viewer.layers.selection.active
|
790
|
+
if active_layer is None:
|
791
|
+
return ()
|
792
|
+
return [i for i in range(active_layer.data.ndim)]
|
793
|
+
|
794
|
+
def _update_align_axis(self, *args):
|
795
|
+
self.align_axis.choices = self._get_active_layer_dims()
|
796
|
+
|
797
|
+
def _align_with_axis(self):
|
798
|
+
active_layer = self.viewer.layers.selection.active
|
799
|
+
|
800
|
+
if self.align_axis.value is None:
|
801
|
+
return None
|
802
|
+
|
803
|
+
if active_layer.metadata.get("is_aligned", None) == self.align_axis.value:
|
804
|
+
return None
|
805
|
+
|
806
|
+
alignment_axis = np.zeros(active_layer.data.ndim)
|
807
|
+
alignment_axis[int(self.align_axis.value)] = 1
|
808
|
+
|
809
|
+
coords = np.array(np.where(active_layer.data > 0)).T
|
810
|
+
centered_coords = coords - np.mean(coords, axis=0)
|
811
|
+
cov_matrix = np.cov(centered_coords, rowvar=False)
|
812
|
+
|
813
|
+
eigenvalues, eigenvectors = np.linalg.eigh(cov_matrix)
|
814
|
+
principal_eigenvector = eigenvectors[:, np.argmax(eigenvalues)]
|
815
|
+
|
816
|
+
rotation_axis = np.cross(principal_eigenvector, alignment_axis)
|
817
|
+
rotation_angle = np.arccos(np.dot(principal_eigenvector, alignment_axis))
|
818
|
+
k = rotation_axis / np.linalg.norm(rotation_axis)
|
819
|
+
K = np.array([[0, -k[2], k[1]], [k[2], 0, -k[0]], [-k[1], k[0], 0]])
|
820
|
+
rotation_matrix = np.eye(3)
|
821
|
+
rotation_matrix += np.sin(rotation_angle) * K
|
822
|
+
rotation_matrix += (1 - np.cos(rotation_angle)) * np.dot(K, K)
|
823
|
+
|
824
|
+
rotated_data = Density.rotate_array(
|
825
|
+
arr=active_layer.data,
|
826
|
+
rotation_matrix=rotation_matrix,
|
827
|
+
use_geometric_center=False,
|
828
|
+
order=1,
|
829
|
+
)
|
830
|
+
eps = np.finfo(rotated_data.dtype).eps
|
831
|
+
rotated_data[rotated_data < eps] = 0
|
832
|
+
|
833
|
+
active_layer.metadata["is_aligned"] = int(self.align_axis.value)
|
834
|
+
active_layer.data = rotated_data
|
716
835
|
|
717
836
|
|
718
837
|
class ExportWidget(widgets.Container):
|
@@ -803,16 +922,36 @@ class PointCloudWidget(widgets.Container):
|
|
803
922
|
options=options,
|
804
923
|
)
|
805
924
|
|
806
|
-
if filename:
|
807
|
-
|
808
|
-
if layer and isinstance(layer, napari.layers.Points):
|
809
|
-
original_dataframe = self.dataframes.get(layer.name, pd.DataFrame())
|
925
|
+
if not filename:
|
926
|
+
return None
|
810
927
|
|
811
|
-
|
812
|
-
|
813
|
-
|
814
|
-
)
|
815
|
-
|
928
|
+
layer = self.viewer.layers.selection.active
|
929
|
+
if layer and isinstance(layer, napari.layers.Points):
|
930
|
+
original_dataframe = self.dataframes.get(
|
931
|
+
layer.name, pd.DataFrame(columns=["z", "y", "x"])
|
932
|
+
)
|
933
|
+
|
934
|
+
export_data = pd.DataFrame(layer.data, columns=["z", "y", "x"])
|
935
|
+
merged_data = pd.merge(
|
936
|
+
export_data, original_dataframe, on=["z", "y", "x"], how="left"
|
937
|
+
)
|
938
|
+
|
939
|
+
merged_data["z"] = merged_data["z"].astype(int)
|
940
|
+
merged_data["y"] = merged_data["y"].astype(int)
|
941
|
+
merged_data["x"] = merged_data["x"].astype(int)
|
942
|
+
|
943
|
+
euler_columns = ["euler_z", "euler_y", "euler_x"]
|
944
|
+
for col in euler_columns:
|
945
|
+
if col not in merged_data.columns:
|
946
|
+
continue
|
947
|
+
merged_data[col] = merged_data[col].fillna(0)
|
948
|
+
|
949
|
+
if "score" in merged_data.columns:
|
950
|
+
merged_data["score"] = merged_data["score"].fillna(1)
|
951
|
+
if "detail" in merged_data.columns:
|
952
|
+
merged_data["detail"] = merged_data["detail"].fillna(2)
|
953
|
+
|
954
|
+
merged_data.to_csv(filename, sep="\t", index=False)
|
816
955
|
|
817
956
|
def _get_load_path(self, event):
|
818
957
|
options = QFileDialog.Options()
|
@@ -830,6 +969,13 @@ class PointCloudWidget(widgets.Container):
|
|
830
969
|
dataframe = pd.read_csv(filename, sep="\t")
|
831
970
|
points = dataframe[["z", "y", "x"]].values
|
832
971
|
layer_name = filename.split("/")[-1]
|
972
|
+
|
973
|
+
if "score" not in dataframe.columns:
|
974
|
+
dataframe["score"] = 1
|
975
|
+
|
976
|
+
if "detail" not in dataframe.columns:
|
977
|
+
dataframe["detail"] = -2
|
978
|
+
|
833
979
|
point_properties = {
|
834
980
|
"score": np.array(dataframe["score"].values),
|
835
981
|
"detail": np.array(dataframe["detail"].values),
|
@@ -849,6 +995,38 @@ class PointCloudWidget(widgets.Container):
|
|
849
995
|
self.dataframes[layer_name] = dataframe
|
850
996
|
|
851
997
|
|
998
|
+
class MatchingWidget(widgets.Container):
|
999
|
+
def __init__(self, viewer):
|
1000
|
+
super().__init__(layout="vertical")
|
1001
|
+
|
1002
|
+
self.viewer = viewer
|
1003
|
+
self.dataframes = {}
|
1004
|
+
|
1005
|
+
self.import_button = widgets.PushButton(name="Import", text="Import Pickle")
|
1006
|
+
self.import_button.clicked.connect(self._get_load_path)
|
1007
|
+
|
1008
|
+
self.append(self.import_button)
|
1009
|
+
|
1010
|
+
def _get_load_path(self, event):
|
1011
|
+
options = QFileDialog.Options()
|
1012
|
+
filename, _ = QFileDialog.getOpenFileName(
|
1013
|
+
self.native,
|
1014
|
+
"Open Pickle File...",
|
1015
|
+
"",
|
1016
|
+
"Pickle Files (*.pickle);;All Files (*)",
|
1017
|
+
options=options,
|
1018
|
+
)
|
1019
|
+
if filename:
|
1020
|
+
self._load_data(filename)
|
1021
|
+
|
1022
|
+
def _load_data(self, filename):
|
1023
|
+
data = load_pickle(filename)
|
1024
|
+
|
1025
|
+
_ = self.viewer.add_image(data=data[2], name="Rotations", colormap="orange")
|
1026
|
+
|
1027
|
+
_ = self.viewer.add_image(data=data[0], name="Scores", colormap="turbo")
|
1028
|
+
|
1029
|
+
|
852
1030
|
def main():
|
853
1031
|
viewer = napari.Viewer()
|
854
1032
|
|
@@ -856,10 +1034,16 @@ def main():
|
|
856
1034
|
mask_widget = MaskWidget(viewer)
|
857
1035
|
export_widget = ExportWidget(viewer)
|
858
1036
|
point_cloud = PointCloudWidget(viewer)
|
1037
|
+
matching_widget = MatchingWidget(viewer)
|
1038
|
+
alignment_widget = AlignmentWidget(viewer)
|
859
1039
|
|
860
1040
|
viewer.window.add_dock_widget(widget=filter_widget, name="Preprocess", area="right")
|
1041
|
+
viewer.window.add_dock_widget(
|
1042
|
+
widget=alignment_widget, name="Alignment", area="right"
|
1043
|
+
)
|
861
1044
|
viewer.window.add_dock_widget(widget=mask_widget, name="Mask", area="right")
|
862
1045
|
viewer.window.add_dock_widget(widget=point_cloud, name="PointCloud", area="left")
|
1046
|
+
viewer.window.add_dock_widget(widget=matching_widget, name="Matching", area="left")
|
863
1047
|
|
864
1048
|
viewer.window.add_dock_widget(widget=export_widget, name="Export", area="right")
|
865
1049
|
|
@@ -0,0 +1,218 @@
|
|
1
|
+
#!python3
|
2
|
+
""" CLI to refine template matching candidates.
|
3
|
+
|
4
|
+
Copyright (c) 2023-2024 European Molecular Biology Laboratory
|
5
|
+
|
6
|
+
Author: Valentin Maurer <valentin.maurer@embl-hamburg.de>
|
7
|
+
"""
|
8
|
+
import argparse
|
9
|
+
from time import time
|
10
|
+
|
11
|
+
import numpy as np
|
12
|
+
from numpy.typing import NDArray
|
13
|
+
|
14
|
+
from tme.backends import backend
|
15
|
+
from tme import Density, Structure
|
16
|
+
from tme.matching_data import MatchingData
|
17
|
+
from tme.analyzer import MaxScoreOverRotations, MaxScoreOverTranslations
|
18
|
+
from tme.matching_utils import (
|
19
|
+
load_pickle,
|
20
|
+
get_rotation_matrices,
|
21
|
+
compute_parallelization_schedule,
|
22
|
+
)
|
23
|
+
from tme.matching_exhaustive import scan_subsets, MATCHING_EXHAUSTIVE_REGISTER
|
24
|
+
|
25
|
+
from postprocess import Orientations
|
26
|
+
from match_template import load_and_validate_mask
|
27
|
+
|
28
|
+
def parse_args():
|
29
|
+
parser = argparse.ArgumentParser(
|
30
|
+
description="Refine Template Matching Orientations."
|
31
|
+
)
|
32
|
+
parser.add_argument(
|
33
|
+
"--input_file",
|
34
|
+
required=True,
|
35
|
+
type=str,
|
36
|
+
help="Path to the output of match_template.py.",
|
37
|
+
)
|
38
|
+
parser.add_argument(
|
39
|
+
"--orientations",
|
40
|
+
required=True,
|
41
|
+
type=str,
|
42
|
+
help="Path to orientations from postprocess.py.",
|
43
|
+
)
|
44
|
+
parser.add_argument(
|
45
|
+
"--output_file",
|
46
|
+
required=True,
|
47
|
+
help="Path to the refined output orientations.",
|
48
|
+
)
|
49
|
+
parser.add_argument(
|
50
|
+
"-n",
|
51
|
+
dest="cores",
|
52
|
+
required=False,
|
53
|
+
type=int,
|
54
|
+
default=4,
|
55
|
+
help="Number of cores used for template matching.",
|
56
|
+
)
|
57
|
+
|
58
|
+
args = parser.parse_args()
|
59
|
+
|
60
|
+
return args
|
61
|
+
|
62
|
+
|
63
|
+
def load_template(filepath: str, sampling_rate: NDArray) -> "Density":
|
64
|
+
try:
|
65
|
+
template = Density.from_file(filepath)
|
66
|
+
except ValueError:
|
67
|
+
template = Structure.from_file(filepath)
|
68
|
+
template = Density.from_structure(template, sampling_rate=sampling_rate)
|
69
|
+
|
70
|
+
return template
|
71
|
+
|
72
|
+
|
73
|
+
def main():
|
74
|
+
args = parse_args()
|
75
|
+
meta = load_pickle(args.input_file)[-1]
|
76
|
+
target_origin, _, sampling_rate, cli_args = meta
|
77
|
+
|
78
|
+
orientations = Orientations.from_file(
|
79
|
+
filename=args.orientations, file_format="text"
|
80
|
+
)
|
81
|
+
|
82
|
+
template = load_template(cli_args.template, sampling_rate)
|
83
|
+
template_mask = load_and_validate_mask(
|
84
|
+
mask_target=template, mask_path=cli_args.template_mask
|
85
|
+
)
|
86
|
+
|
87
|
+
if not cli_args.no_centering:
|
88
|
+
template, translation = template.centered(0)
|
89
|
+
|
90
|
+
if template_mask is None:
|
91
|
+
template_mask = template.empty
|
92
|
+
if not cli_args.no_centering:
|
93
|
+
enclosing_box = template.minimum_enclosing_box(
|
94
|
+
0, use_geometric_center=False
|
95
|
+
)
|
96
|
+
template_mask.adjust_box(enclosing_box)
|
97
|
+
|
98
|
+
template_mask.data[:] = 1
|
99
|
+
translation = np.zeros_like(translation)
|
100
|
+
|
101
|
+
template_mask.pad(template.shape, center=False)
|
102
|
+
origin_translation = np.divide(
|
103
|
+
np.subtract(template.origin, template_mask.origin), template.sampling_rate
|
104
|
+
)
|
105
|
+
translation = np.add(translation, origin_translation)
|
106
|
+
|
107
|
+
template_mask = template_mask.rigid_transform(
|
108
|
+
rotation_matrix=np.eye(template_mask.data.ndim),
|
109
|
+
translation=-translation,
|
110
|
+
order=1,
|
111
|
+
)
|
112
|
+
template_mask.origin = template.origin.copy()
|
113
|
+
|
114
|
+
target = Density.from_file(cli_args.target)
|
115
|
+
peaks = orientations.translations.astype(int)
|
116
|
+
half_shape = np.divide(template.shape, 2).astype(int)
|
117
|
+
observation_starts = np.subtract(peaks, half_shape)
|
118
|
+
observation_stops = np.add(peaks, half_shape) + np.mod(template.shape, 2).astype(
|
119
|
+
int
|
120
|
+
)
|
121
|
+
|
122
|
+
pruned_starts = np.maximum(observation_starts, 0)
|
123
|
+
pruned_stops = np.minimum(observation_stops, target.shape)
|
124
|
+
|
125
|
+
keep_peaks = (
|
126
|
+
np.sum(
|
127
|
+
np.multiply(
|
128
|
+
observation_starts == pruned_starts, observation_stops == pruned_stops
|
129
|
+
),
|
130
|
+
axis=1,
|
131
|
+
)
|
132
|
+
== observation_starts.shape[1]
|
133
|
+
)
|
134
|
+
observation_starts = observation_starts[keep_peaks]
|
135
|
+
observation_stops = observation_stops[keep_peaks]
|
136
|
+
|
137
|
+
observation_slices = [
|
138
|
+
tuple(slice(s, e) for s, e in zip(start_row, stop_row))
|
139
|
+
for start_row, stop_row in zip(observation_starts, observation_stops)
|
140
|
+
]
|
141
|
+
|
142
|
+
matching_data = MatchingData(target=target, template=template)
|
143
|
+
matching_data.rotations = np.eye(template.data.ndim).reshape(1, 3, 3)
|
144
|
+
|
145
|
+
target_pad = matching_data.target_padding(pad_target=True)
|
146
|
+
out_shape = np.add(target_pad, template.shape).astype(int)
|
147
|
+
|
148
|
+
observations = np.zeros((len(observation_slices), *out_shape))
|
149
|
+
|
150
|
+
|
151
|
+
for idx, obs_slice in enumerate(observation_slices):
|
152
|
+
subset = matching_data.subset_by_slice(
|
153
|
+
target_slice=obs_slice,
|
154
|
+
target_pad=target_pad,
|
155
|
+
invert_target=cli_args.invert_target_contrast,
|
156
|
+
)
|
157
|
+
xd = template.copy()
|
158
|
+
xd.pad(subset.target.shape, center = True)
|
159
|
+
# observations[idx] = subset.target
|
160
|
+
observations[idx] = xd.data
|
161
|
+
|
162
|
+
|
163
|
+
matching_data = MatchingData(target=observations, template=template)
|
164
|
+
matching_data._set_batch_dimension(target_dims=0, template_dims=None)
|
165
|
+
matching_data.rotations = get_rotation_matrices(15)
|
166
|
+
if template_mask is not None:
|
167
|
+
matching_data.template_mask = template_mask.data
|
168
|
+
|
169
|
+
template_box = np.ones(len(matching_data._output_template_shape), dtype=int)
|
170
|
+
target_padding = np.zeros_like(matching_data._output_target_shape)
|
171
|
+
|
172
|
+
scoring_method = "FLC"
|
173
|
+
callback_class = MaxScoreOverRotations
|
174
|
+
splits, schedule = compute_parallelization_schedule(
|
175
|
+
shape1=matching_data._output_target_shape,
|
176
|
+
shape2=template_box,
|
177
|
+
shape1_padding=target_padding,
|
178
|
+
max_cores=args.cores,
|
179
|
+
max_ram=backend.get_available_memory(),
|
180
|
+
split_only_outer=False,
|
181
|
+
matching_method=scoring_method,
|
182
|
+
analyzer_method=callback_class.__name__,
|
183
|
+
backend=backend._backend_name,
|
184
|
+
float_nbytes=backend.datatype_bytes(backend._default_dtype),
|
185
|
+
complex_nbytes=backend.datatype_bytes(backend._complex_dtype),
|
186
|
+
integer_nbytes=backend.datatype_bytes(backend._default_dtype_int),
|
187
|
+
)
|
188
|
+
|
189
|
+
matching_setup, matching_score = MATCHING_EXHAUSTIVE_REGISTER[scoring_method]
|
190
|
+
|
191
|
+
start = time()
|
192
|
+
candidates = scan_subsets(
|
193
|
+
matching_data=matching_data,
|
194
|
+
matching_score=matching_score,
|
195
|
+
matching_setup=matching_setup,
|
196
|
+
callback_class=callback_class,
|
197
|
+
callback_class_args={
|
198
|
+
# "score_space_shape" : (
|
199
|
+
# matching_data.rotations.shape[0], observations.shape[0]
|
200
|
+
# ),
|
201
|
+
# "score_space_dtype" : backend._default_dtype,
|
202
|
+
# "template_shape" : (matching_data.rotations.shape[0], *matching_data._template.shape)
|
203
|
+
},
|
204
|
+
target_splits=splits,
|
205
|
+
job_schedule=schedule,
|
206
|
+
pad_target_edges=False,
|
207
|
+
pad_fourier=False,
|
208
|
+
interpolation_order=cli_args.interpolation_order,
|
209
|
+
)
|
210
|
+
print(candidates[0].max())
|
211
|
+
Density(candidates[0][0]).to_file("scores.mrc")
|
212
|
+
|
213
|
+
runtime = time() - start
|
214
|
+
print(f"\nRuntime real: {runtime:.3f}s user: {(runtime * args.cores):.3f}s.")
|
215
|
+
|
216
|
+
|
217
|
+
if __name__ == "__main__":
|
218
|
+
main()
|
tme/__init__.py
CHANGED
@@ -1,6 +1,7 @@
|
|
1
|
+
from . import extensions
|
1
2
|
from .__version__ import __version__
|
2
3
|
from .density import Density
|
3
4
|
from .preprocessor import Preprocessor
|
4
5
|
from .structure import Structure
|
6
|
+
from .orientations import Orientations
|
5
7
|
from .matching_optimization import FitRefinement
|
6
|
-
from . import extensions
|
tme/__version__.py
CHANGED
@@ -1 +1 @@
|
|
1
|
-
__version__ = "0.
|
1
|
+
__version__ = "0.2.0"
|