pytme 0.2.0b0__cp311-cp311-macosx_14_0_arm64.whl → 0.2.2__cp311-cp311-macosx_14_0_arm64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (52) hide show
  1. pytme-0.2.2.data/scripts/match_template.py +1187 -0
  2. {pytme-0.2.0b0.data → pytme-0.2.2.data}/scripts/postprocess.py +170 -71
  3. {pytme-0.2.0b0.data → pytme-0.2.2.data}/scripts/preprocessor_gui.py +179 -86
  4. pytme-0.2.2.dist-info/METADATA +91 -0
  5. pytme-0.2.2.dist-info/RECORD +74 -0
  6. {pytme-0.2.0b0.dist-info → pytme-0.2.2.dist-info}/WHEEL +1 -1
  7. scripts/extract_candidates.py +126 -87
  8. scripts/match_template.py +596 -209
  9. scripts/match_template_filters.py +571 -223
  10. scripts/postprocess.py +170 -71
  11. scripts/preprocessor_gui.py +179 -86
  12. scripts/refine_matches.py +567 -159
  13. tme/__init__.py +0 -1
  14. tme/__version__.py +1 -1
  15. tme/analyzer.py +627 -855
  16. tme/backends/__init__.py +41 -11
  17. tme/backends/_jax_utils.py +185 -0
  18. tme/backends/cupy_backend.py +120 -225
  19. tme/backends/jax_backend.py +282 -0
  20. tme/backends/matching_backend.py +464 -388
  21. tme/backends/mlx_backend.py +45 -68
  22. tme/backends/npfftw_backend.py +256 -514
  23. tme/backends/pytorch_backend.py +41 -154
  24. tme/density.py +312 -421
  25. tme/extensions.cpython-311-darwin.so +0 -0
  26. tme/matching_data.py +366 -303
  27. tme/matching_exhaustive.py +279 -1521
  28. tme/matching_optimization.py +234 -129
  29. tme/matching_scores.py +884 -0
  30. tme/matching_utils.py +281 -387
  31. tme/memory.py +377 -0
  32. tme/orientations.py +226 -66
  33. tme/parser.py +3 -4
  34. tme/preprocessing/__init__.py +2 -0
  35. tme/preprocessing/_utils.py +217 -0
  36. tme/preprocessing/composable_filter.py +31 -0
  37. tme/preprocessing/compose.py +55 -0
  38. tme/preprocessing/frequency_filters.py +388 -0
  39. tme/preprocessing/tilt_series.py +1011 -0
  40. tme/preprocessor.py +574 -530
  41. tme/structure.py +495 -189
  42. tme/types.py +5 -3
  43. pytme-0.2.0b0.data/scripts/match_template.py +0 -800
  44. pytme-0.2.0b0.dist-info/METADATA +0 -73
  45. pytme-0.2.0b0.dist-info/RECORD +0 -66
  46. tme/helpers.py +0 -881
  47. tme/matching_constrained.py +0 -195
  48. {pytme-0.2.0b0.data → pytme-0.2.2.data}/scripts/estimate_ram_usage.py +0 -0
  49. {pytme-0.2.0b0.data → pytme-0.2.2.data}/scripts/preprocess.py +0 -0
  50. {pytme-0.2.0b0.dist-info → pytme-0.2.2.dist-info}/LICENSE +0 -0
  51. {pytme-0.2.0b0.dist-info → pytme-0.2.2.dist-info}/entry_points.txt +0 -0
  52. {pytme-0.2.0b0.dist-info → pytme-0.2.2.dist-info}/top_level.txt +0 -0
@@ -1,7 +1,5 @@
1
1
  #!python3
2
- """ Simplify picking adequate filtering and masking parameters using a GUI.
3
- Exposes tme.preprocessor.Preprocessor and tme.fitter_utils member functions
4
- to achieve this aim.
2
+ """ GUI for identifying adequate template matching filter and masks.
5
3
 
6
4
  Copyright (c) 2023 European Molecular Biology Laboratory
7
5
 
@@ -12,17 +10,20 @@ import argparse
12
10
  from typing import Tuple, Callable, List
13
11
  from typing_extensions import Annotated
14
12
 
13
+ import napari
15
14
  import numpy as np
16
15
  import pandas as pd
17
- import napari
16
+ from scipy.fft import next_fast_len
18
17
  from napari.layers import Image
19
18
  from napari.utils.events import EventedList
20
-
21
19
  from magicgui import widgets
22
20
  from qtpy.QtWidgets import QFileDialog
23
21
  from numpy.typing import NDArray
24
22
 
23
+ from tme.backends import backend
25
24
  from tme import Preprocessor, Density
25
+ from tme.preprocessing import BandPassFilter
26
+ from tme.preprocessing.tilt_series import CTF
26
27
  from tme.matching_utils import create_mask, load_pickle
27
28
 
28
29
  preprocessor = Preprocessor()
@@ -35,19 +36,57 @@ def gaussian_filter(template: NDArray, sigma: float, **kwargs: dict) -> NDArray:
35
36
 
36
37
  def bandpass_filter(
37
38
  template: NDArray,
38
- minimum_frequency: float,
39
- maximum_frequency: float,
40
- gaussian_sigma: float,
41
- **kwargs: dict,
39
+ lowpass_angstrom: float = 30,
40
+ highpass_angstrom: float = 140,
41
+ hard_edges: bool = False,
42
+ sampling_rate=None,
42
43
  ) -> NDArray:
43
- return preprocessor.bandpass_filter(
44
- template=template,
45
- minimum_frequency=minimum_frequency,
46
- maximum_frequency=maximum_frequency,
47
- sampling_rate=1,
48
- gaussian_sigma=gaussian_sigma,
49
- **kwargs,
44
+ bpf = BandPassFilter(
45
+ lowpass=lowpass_angstrom,
46
+ highpass=highpass_angstrom,
47
+ sampling_rate=np.max(sampling_rate),
48
+ use_gaussian=not hard_edges,
49
+ shape_is_real_fourier=True,
50
+ return_real_fourier=True,
51
+ )
52
+ template_ft = np.fft.rfftn(template, s=template.shape)
53
+
54
+ mask = bpf(shape=template_ft.shape)["data"]
55
+ np.multiply(template_ft, mask, out=template_ft)
56
+ return np.fft.irfftn(template_ft, s=template.shape).real
57
+
58
+
59
+ def ctf_filter(
60
+ template: NDArray,
61
+ defocus_angstrom: float = 30000,
62
+ acceleration_voltage: float = 300,
63
+ spherical_aberration: float = 2.7,
64
+ amplitude_contrast: float = 0.07,
65
+ phase_shift: float = 0,
66
+ defocus_angle: float = 0,
67
+ sampling_rate=None,
68
+ flip_phase: bool = False,
69
+ ) -> NDArray:
70
+ fast_shape = [next_fast_len(x) for x in np.multiply(template.shape, 2)]
71
+ template_pad = backend.topleft_pad(template, fast_shape)
72
+ template_ft = np.fft.rfftn(template_pad, s=template_pad.shape)
73
+ ctf = CTF(
74
+ angles=[0],
75
+ shape=fast_shape,
76
+ defocus_x=[defocus_angstrom],
77
+ acceleration_voltage=acceleration_voltage * 1e3,
78
+ spherical_aberration=spherical_aberration * 1e7,
79
+ amplitude_contrast=amplitude_contrast,
80
+ phase_shift=[phase_shift],
81
+ defocus_angle=[defocus_angle],
82
+ sampling_rate=np.max(sampling_rate),
83
+ return_real_fourier=True,
84
+ flip_phase=flip_phase,
50
85
  )
86
+ np.multiply(template_ft, ctf()["data"], out=template_ft)
87
+ template_pad = np.fft.irfftn(template_ft, s=template_pad.shape).real
88
+ template = backend.topleft_pad(template_pad, template.shape)
89
+ return template
51
90
 
52
91
 
53
92
  def difference_of_gaussian_filter(
@@ -109,61 +148,6 @@ def mean(
109
148
  return preprocessor.mean_filter(template=template, width=width)
110
149
 
111
150
 
112
- def resolution_sphere(
113
- template: NDArray,
114
- cutoff_angstrom: float,
115
- highpass: bool = False,
116
- sampling_rate=None,
117
- ) -> NDArray:
118
- if cutoff_angstrom == 0:
119
- return template
120
-
121
- cutoff_frequency = np.max(2 * sampling_rate / cutoff_angstrom)
122
-
123
- min_freq, max_freq = 0, cutoff_frequency
124
- if highpass:
125
- min_freq, max_freq = cutoff_frequency, 1e10
126
-
127
- mask = preprocessor.bandpass_mask(
128
- shape=template.shape,
129
- minimum_frequency=min_freq,
130
- maximum_frequency=max_freq,
131
- omit_negative_frequencies=False,
132
- )
133
-
134
- template_ft = np.fft.fftn(template)
135
- np.multiply(template_ft, mask, out=template_ft)
136
- return np.fft.ifftn(template_ft).real
137
-
138
-
139
- def resolution_gaussian(
140
- template: NDArray,
141
- cutoff_angstrom: float,
142
- highpass: bool = False,
143
- sampling_rate=None,
144
- ) -> NDArray:
145
- if cutoff_angstrom == 0:
146
- return template
147
-
148
- grid = preprocessor.fftfreqn(
149
- shape=template.shape, sampling_rate=sampling_rate / sampling_rate.max()
150
- )
151
-
152
- sigma_fourier = np.divide(
153
- np.max(2 * sampling_rate / cutoff_angstrom), np.sqrt(2 * np.log(2))
154
- )
155
-
156
- mask = np.exp(-(grid**2) / (2 * sigma_fourier**2))
157
- if highpass:
158
- mask = 1 - mask
159
-
160
- mask = np.fft.ifftshift(mask)
161
-
162
- template_ft = np.fft.fftn(template)
163
- np.multiply(template_ft, mask, out=template_ft)
164
- return np.fft.ifftn(template_ft).real
165
-
166
-
167
151
  def wedge(
168
152
  template: NDArray,
169
153
  tilt_start: float,
@@ -274,8 +258,7 @@ WRAPPED_FUNCTIONS = {
274
258
  "mean_filter": mean,
275
259
  "wedge_filter": wedge,
276
260
  "power_spectrum": compute_power_spectrum,
277
- "resolution_gaussian": resolution_gaussian,
278
- "resolution_sphere": resolution_sphere,
261
+ "ctf": ctf_filter,
279
262
  }
280
263
 
281
264
  EXCLUDED_FUNCTIONS = [
@@ -421,6 +404,7 @@ def sphere_mask(
421
404
  center_y: float,
422
405
  center_z: float,
423
406
  radius: float,
407
+ sigma_decay: float = 0,
424
408
  **kwargs,
425
409
  ) -> NDArray:
426
410
  return create_mask(
@@ -428,6 +412,7 @@ def sphere_mask(
428
412
  shape=template.shape,
429
413
  center=(center_x, center_y, center_z),
430
414
  radius=radius,
415
+ sigma_decay=sigma_decay,
431
416
  )
432
417
 
433
418
 
@@ -439,6 +424,7 @@ def ellipsod_mask(
439
424
  radius_x: float,
440
425
  radius_y: float,
441
426
  radius_z: float,
427
+ sigma_decay: float = 0,
442
428
  **kwargs,
443
429
  ) -> NDArray:
444
430
  return create_mask(
@@ -446,6 +432,7 @@ def ellipsod_mask(
446
432
  shape=template.shape,
447
433
  center=(center_x, center_y, center_z),
448
434
  radius=(radius_x, radius_y, radius_z),
435
+ sigma_decay=sigma_decay,
449
436
  )
450
437
 
451
438
 
@@ -457,6 +444,7 @@ def box_mask(
457
444
  height_x: int,
458
445
  height_y: int,
459
446
  height_z: int,
447
+ sigma_decay: float = 0,
460
448
  **kwargs,
461
449
  ) -> NDArray:
462
450
  return create_mask(
@@ -464,6 +452,7 @@ def box_mask(
464
452
  shape=template.shape,
465
453
  center=(center_x, center_y, center_z),
466
454
  height=(height_x, height_y, height_z),
455
+ sigma_decay=sigma_decay,
467
456
  )
468
457
 
469
458
 
@@ -476,6 +465,7 @@ def tube_mask(
476
465
  inner_radius: float,
477
466
  outer_radius: float,
478
467
  height: int,
468
+ sigma_decay: float = 0,
479
469
  **kwargs,
480
470
  ) -> NDArray:
481
471
  return create_mask(
@@ -486,6 +476,7 @@ def tube_mask(
486
476
  inner_radius=inner_radius,
487
477
  outer_radius=outer_radius,
488
478
  height=height,
479
+ sigma_decay=sigma_decay,
489
480
  )
490
481
 
491
482
 
@@ -533,13 +524,23 @@ def wedge_mask(
533
524
 
534
525
 
535
526
  def threshold_mask(
536
- template: NDArray, standard_deviation: float = 5.0, invert: bool = False, **kwargs
527
+ template: NDArray,
528
+ invert: bool = False,
529
+ standard_deviation: float = 5.0,
530
+ sigma: float = 0.0,
531
+ **kwargs,
537
532
  ) -> NDArray:
538
533
  template_mean = template.mean()
539
534
  template_deviation = standard_deviation * template.std()
540
535
  upper = template_mean + template_deviation
541
536
  lower = template_mean - template_deviation
542
- mask = np.logical_and(template > lower, template < upper)
537
+ mask = np.logical_or(template <= lower, template >= upper)
538
+
539
+ if sigma != 0:
540
+ mask_filter = preprocessor.gaussian_filter(template=mask * 1.0, sigma=sigma)
541
+ mask = np.add(mask, (1 - mask) * mask_filter)
542
+ mask[mask < np.exp(-np.square(sigma))] = 0
543
+
543
544
  if invert:
544
545
  np.invert(mask, out=mask)
545
546
 
@@ -789,7 +790,10 @@ class AlignmentWidget(widgets.Container):
789
790
  active_layer = self.viewer.layers.selection.active
790
791
  if active_layer is None:
791
792
  return ()
792
- return [i for i in range(active_layer.data.ndim)]
793
+ try:
794
+ return [i for i in range(active_layer.data.ndim)]
795
+ except Exception:
796
+ return ()
793
797
 
794
798
  def _update_align_axis(self, *args):
795
799
  self.align_axis.choices = self._get_active_layer_dims()
@@ -887,6 +891,7 @@ class PointCloudWidget(widgets.Container):
887
891
 
888
892
  self.viewer = viewer
889
893
  self.dataframes = {}
894
+ self.selected_category = -1
890
895
 
891
896
  self.import_button = widgets.PushButton(
892
897
  name="Import", text="Import Point Cloud"
@@ -899,10 +904,98 @@ class PointCloudWidget(widgets.Container):
899
904
  self.export_button.clicked.connect(self._export_point_cloud)
900
905
  self.export_button.enabled = False
901
906
 
907
+ self.annotation_container = widgets.Container(name="Label", layout="horizontal")
908
+ self.positive_button = widgets.PushButton(name="Positive", text="Positive")
909
+ self.negative_button = widgets.PushButton(name="Negative", text="Negative")
910
+ self.positive_button.clicked.connect(self._set_positive)
911
+ self.negative_button.clicked.connect(self._set_negative)
912
+ self.annotation_container.append(self.positive_button)
913
+ self.annotation_container.append(self.negative_button)
914
+
915
+ self.face_color_select = widgets.ComboBox(
916
+ name="Color", choices=["Label", "Score"], value=None, nullable=True
917
+ )
918
+ self.face_color_select.changed.connect(self._update_face_color_mode)
919
+
902
920
  self.append(self.import_button)
903
921
  self.append(self.export_button)
922
+ self.append(self.annotation_container)
923
+ self.append(self.face_color_select)
924
+
904
925
  self.viewer.layers.selection.events.changed.connect(self._update_buttons)
905
926
 
927
+ self.viewer.layers.events.inserted.connect(self._initialize_points_layer)
928
+
929
+ def _update_face_color_mode(self, event: str = None):
930
+ for layer in self.viewer.layers:
931
+ if not isinstance(layer, napari.layers.Points):
932
+ continue
933
+
934
+ layer.face_color = "white"
935
+ if event == "Label":
936
+ if len(layer.properties.get("detail", ())) == 0:
937
+ continue
938
+ layer.face_color = "detail"
939
+ layer.face_color_cycle = {
940
+ -1: "grey",
941
+ 0: "red",
942
+ 1: "green",
943
+ }
944
+ layer.face_color_mode = "cycle"
945
+ elif event == "Score":
946
+ if len(layer.properties.get("score_scaled", ())) == 0:
947
+ continue
948
+ layer.face_color = "score_scaled"
949
+ layer.face_colormap = "turbo"
950
+ layer.face_color_mode = "colormap"
951
+
952
+ layer.refresh_colors()
953
+
954
+ return None
955
+
956
+ def _set_positive(self, event):
957
+ self.selected_category = 1 if self.selected_category != 1 else -1
958
+ self._update_annotation_buttons()
959
+
960
+ def _set_negative(self, event):
961
+ self.selected_category = 0 if self.selected_category != 0 else -1
962
+ self._update_annotation_buttons()
963
+
964
+ def _update_annotation_buttons(self):
965
+ selected_style = "background-color: darkgrey"
966
+ default_style = "background-color: none"
967
+
968
+ self.positive_button.native.setStyleSheet(
969
+ selected_style if self.selected_category == 1 else default_style
970
+ )
971
+ self.negative_button.native.setStyleSheet(
972
+ selected_style if self.selected_category == 0 else default_style
973
+ )
974
+
975
+ def _initialize_points_layer(self, event):
976
+ layer = event.value
977
+ if not isinstance(layer, napari.layers.Points):
978
+ return
979
+ if len(layer.properties) == 0:
980
+ layer.properties = {"detail": [-1]}
981
+
982
+ if "detail" not in layer.properties:
983
+ layer["detail"] = [-1]
984
+
985
+ layer.mouse_drag_callbacks.append(self._on_click)
986
+ return None
987
+
988
+ def _on_click(self, layer, event):
989
+ if layer.mode == "add":
990
+ layer.current_properties["detail"][-1] = self.selected_category
991
+ elif layer.mode == "select":
992
+ for index in layer.selected_data:
993
+ layer.properties["detail"][index] = self.selected_category
994
+
995
+ # TODO: Check whether current face color is the desired one already
996
+ self._update_face_color_mode(self.face_color_select.value)
997
+ layer.refresh_colors()
998
+
906
999
  def _update_buttons(self, event):
907
1000
  is_pointcloud = isinstance(
908
1001
  self.viewer.layers.selection.active, napari.layers.Points
@@ -948,9 +1041,7 @@ class PointCloudWidget(widgets.Container):
948
1041
 
949
1042
  if "score" in merged_data.columns:
950
1043
  merged_data["score"] = merged_data["score"].fillna(1)
951
- if "detail" in merged_data.columns:
952
- merged_data["detail"] = merged_data["detail"].fillna(2)
953
-
1044
+ merged_data["detail"] = layer.properties["detail"]
954
1045
  merged_data.to_csv(filename, sep="\t", index=False)
955
1046
 
956
1047
  def _get_load_path(self, event):
@@ -974,7 +1065,7 @@ class PointCloudWidget(widgets.Container):
974
1065
  dataframe["score"] = 1
975
1066
 
976
1067
  if "detail" not in dataframe.columns:
977
- dataframe["detail"] = -2
1068
+ dataframe["detail"] = -1
978
1069
 
979
1070
  point_properties = {
980
1071
  "score": np.array(dataframe["score"].values),
@@ -988,8 +1079,6 @@ class PointCloudWidget(widgets.Container):
988
1079
  points,
989
1080
  size=10,
990
1081
  properties=point_properties,
991
- face_color="score_scaled",
992
- face_colormap="turbo",
993
1082
  name=layer_name,
994
1083
  )
995
1084
  self.dataframes[layer_name] = dataframe
@@ -1022,9 +1111,14 @@ class MatchingWidget(widgets.Container):
1022
1111
  def _load_data(self, filename):
1023
1112
  data = load_pickle(filename)
1024
1113
 
1025
- _ = self.viewer.add_image(data=data[2], name="Rotations", colormap="orange")
1114
+ metadata = {"origin": data[-1][1], "sampling_rate": data[-1][2]}
1115
+ _ = self.viewer.add_image(
1116
+ data=data[2], name="Rotations", colormap="orange", metadata=metadata
1117
+ )
1026
1118
 
1027
- _ = self.viewer.add_image(data=data[0], name="Scores", colormap="turbo")
1119
+ _ = self.viewer.add_image(
1120
+ data=data[0], name="Scores", colormap="turbo", metadata=metadata
1121
+ )
1028
1122
 
1029
1123
 
1030
1124
  def main():
@@ -1042,11 +1136,10 @@ def main():
1042
1136
  widget=alignment_widget, name="Alignment", area="right"
1043
1137
  )
1044
1138
  viewer.window.add_dock_widget(widget=mask_widget, name="Mask", area="right")
1139
+ viewer.window.add_dock_widget(widget=export_widget, name="Export", area="right")
1045
1140
  viewer.window.add_dock_widget(widget=point_cloud, name="PointCloud", area="left")
1046
1141
  viewer.window.add_dock_widget(widget=matching_widget, name="Matching", area="left")
1047
1142
 
1048
- viewer.window.add_dock_widget(widget=export_widget, name="Export", area="right")
1049
-
1050
1143
  napari.run()
1051
1144
 
1052
1145