pytme 0.3b0__cp311-cp311-macosx_15_0_arm64.whl → 0.3.1__cp311-cp311-macosx_15_0_arm64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (73)
  1. {pytme-0.3b0.data → pytme-0.3.1.data}/scripts/estimate_memory_usage.py +1 -5
  2. {pytme-0.3b0.data → pytme-0.3.1.data}/scripts/match_template.py +177 -226
  3. {pytme-0.3b0.data → pytme-0.3.1.data}/scripts/postprocess.py +69 -47
  4. {pytme-0.3b0.data → pytme-0.3.1.data}/scripts/preprocess.py +10 -23
  5. {pytme-0.3b0.data → pytme-0.3.1.data}/scripts/preprocessor_gui.py +98 -28
  6. pytme-0.3.1.data/scripts/pytme_runner.py +1223 -0
  7. {pytme-0.3b0.dist-info → pytme-0.3.1.dist-info}/METADATA +15 -15
  8. pytme-0.3.1.dist-info/RECORD +133 -0
  9. {pytme-0.3b0.dist-info → pytme-0.3.1.dist-info}/entry_points.txt +1 -0
  10. pytme-0.3.1.dist-info/licenses/LICENSE +339 -0
  11. scripts/estimate_memory_usage.py +1 -5
  12. scripts/eval.py +93 -0
  13. scripts/extract_candidates.py +118 -99
  14. scripts/match_template.py +177 -226
  15. scripts/match_template_filters.py +1200 -0
  16. scripts/postprocess.py +69 -47
  17. scripts/preprocess.py +10 -23
  18. scripts/preprocessor_gui.py +98 -28
  19. scripts/pytme_runner.py +1223 -0
  20. scripts/refine_matches.py +156 -387
  21. tests/data/.DS_Store +0 -0
  22. tests/data/Blurring/.DS_Store +0 -0
  23. tests/data/Maps/.DS_Store +0 -0
  24. tests/data/Raw/.DS_Store +0 -0
  25. tests/data/Structures/.DS_Store +0 -0
  26. tests/preprocessing/test_frequency_filters.py +19 -10
  27. tests/preprocessing/test_utils.py +18 -0
  28. tests/test_analyzer.py +122 -122
  29. tests/test_backends.py +4 -9
  30. tests/test_density.py +0 -1
  31. tests/test_matching_cli.py +30 -30
  32. tests/test_matching_data.py +5 -5
  33. tests/test_matching_utils.py +11 -61
  34. tests/test_rotations.py +1 -1
  35. tme/__version__.py +1 -1
  36. tme/analyzer/__init__.py +1 -1
  37. tme/analyzer/_utils.py +5 -8
  38. tme/analyzer/aggregation.py +28 -9
  39. tme/analyzer/base.py +25 -36
  40. tme/analyzer/peaks.py +49 -122
  41. tme/analyzer/proxy.py +1 -0
  42. tme/backends/_jax_utils.py +31 -28
  43. tme/backends/_numpyfftw_utils.py +270 -0
  44. tme/backends/cupy_backend.py +11 -54
  45. tme/backends/jax_backend.py +72 -48
  46. tme/backends/matching_backend.py +6 -51
  47. tme/backends/mlx_backend.py +1 -27
  48. tme/backends/npfftw_backend.py +95 -90
  49. tme/backends/pytorch_backend.py +5 -26
  50. tme/density.py +7 -10
  51. tme/extensions.cpython-311-darwin.so +0 -0
  52. tme/filters/__init__.py +2 -2
  53. tme/filters/_utils.py +32 -7
  54. tme/filters/bandpass.py +225 -186
  55. tme/filters/ctf.py +138 -87
  56. tme/filters/reconstruction.py +38 -9
  57. tme/filters/wedge.py +98 -112
  58. tme/filters/whitening.py +1 -6
  59. tme/mask.py +341 -0
  60. tme/matching_data.py +20 -44
  61. tme/matching_exhaustive.py +46 -56
  62. tme/matching_optimization.py +2 -1
  63. tme/matching_scores.py +216 -412
  64. tme/matching_utils.py +82 -424
  65. tme/memory.py +1 -1
  66. tme/orientations.py +16 -8
  67. tme/parser.py +109 -29
  68. tme/preprocessor.py +2 -2
  69. tme/rotations.py +1 -1
  70. pytme-0.3b0.dist-info/RECORD +0 -122
  71. pytme-0.3b0.dist-info/licenses/LICENSE +0 -153
  72. {pytme-0.3b0.dist-info → pytme-0.3.1.dist-info}/WHEEL +0 -0
  73. {pytme-0.3b0.dist-info → pytme-0.3.1.dist-info}/top_level.txt +0 -0
@@ -19,14 +19,14 @@ from magicgui import widgets
19
19
  from numpy.typing import NDArray
20
20
  from napari.layers import Image
21
21
  from scipy.fft import next_fast_len
22
- from qtpy.QtWidgets import QFileDialog
23
22
  from napari.utils.events import EventedList
23
+ from qtpy.QtWidgets import QFileDialog, QMessageBox
24
24
 
25
- from tme.backends import backend
25
+ from tme.backends import backend as be
26
26
  from tme.rotations import align_vectors
27
- from tme import Preprocessor, Density, Orientations
28
27
  from tme.matching_utils import create_mask, load_pickle
29
- from tme.filters import BandPassFilter, CTFReconstructed
28
+ from tme import Preprocessor, Density, Orientations
29
+ from tme.filters import BandPassReconstructed, CTFReconstructed
30
30
 
31
31
  preprocessor = Preprocessor()
32
32
  SLIDER_MIN, SLIDER_MAX = 0, 25
@@ -43,17 +43,16 @@ def bandpass_filter(
43
43
  hard_edges: bool = False,
44
44
  sampling_rate=None,
45
45
  ) -> NDArray:
46
- bpf = BandPassFilter(
46
+ bpf = BandPassReconstructed(
47
47
  lowpass=lowpass_angstrom,
48
48
  highpass=highpass_angstrom,
49
49
  sampling_rate=np.max(sampling_rate),
50
50
  use_gaussian=not hard_edges,
51
- shape_is_real_fourier=True,
52
51
  return_real_fourier=True,
53
52
  )
54
53
  template_ft = np.fft.rfftn(template, s=template.shape)
55
54
 
56
- mask = bpf(shape=template_ft.shape)["data"]
55
+ mask = bpf(shape=template.shape)["data"]
57
56
  np.multiply(template_ft, mask, out=template_ft)
58
57
  return np.fft.irfftn(template_ft, s=template.shape).real
59
58
 
@@ -70,7 +69,7 @@ def ctf_filter(
70
69
  flip_phase: bool = False,
71
70
  ) -> NDArray:
72
71
  fast_shape = [next_fast_len(x) for x in np.multiply(template.shape, 2)]
73
- template_pad = backend.topleft_pad(template, fast_shape)
72
+ template_pad = be.topleft_pad(template, fast_shape)
74
73
  template_ft = np.fft.rfftn(template_pad, s=template_pad.shape)
75
74
  ctf = CTFReconstructed(
76
75
  shape=fast_shape,
@@ -86,7 +85,7 @@ def ctf_filter(
86
85
  )
87
86
  np.multiply(template_ft, ctf()["data"], out=template_ft)
88
87
  template_pad = np.fft.irfftn(template_ft, s=template_pad.shape).real
89
- template = backend.topleft_pad(template_pad, template.shape)
88
+ template = be.topleft_pad(template_pad, template.shape)
90
89
  return template
91
90
 
92
91
 
@@ -393,7 +392,13 @@ class FilterWidget(widgets.Container):
393
392
  metadata = selected_layer_metadata.copy()
394
393
  if "filter_parameters" not in metadata:
395
394
  metadata["filter_parameters"] = []
396
- metadata["filter_parameters"].append({filter_name: kwargs.copy()})
395
+
396
+ payload = {filter_name: kwargs.copy()}
397
+ if isinstance(metadata["filter_parameters"], dict):
398
+ metadata["filter_parameters"].update(payload)
399
+ else:
400
+ metadata["filter_parameters"].append(payload)
401
+
397
402
  metadata["used_filter"] = filter_name
398
403
  new_layer.metadata = metadata
399
404
 
@@ -451,7 +456,29 @@ def box_mask(
451
456
  mask_type="box",
452
457
  shape=template.shape,
453
458
  center=(center_x, center_y, center_z),
454
- height=(height_x, height_y, height_z),
459
+ size=(height_x, height_y, height_z),
460
+ sigma_decay=sigma_decay,
461
+ )
462
+
463
+
464
+ def membrane_mask(
465
+ template: NDArray,
466
+ symmetry_axis: int,
467
+ center_x: float,
468
+ center_y: float,
469
+ center_z: float,
470
+ radius: float,
471
+ thickness: float = 1,
472
+ separation: float = 3,
473
+ sigma_decay: float = 0,
474
+ **kwargs,
475
+ ) -> NDArray:
476
+ return create_mask(
477
+ mask_type="membrane",
478
+ shape=template.shape,
479
+ radius=radius,
480
+ thickness=thickness,
481
+ separation=separation,
455
482
  sigma_decay=sigma_decay,
456
483
  )
457
484
 
@@ -472,7 +499,7 @@ def tube_mask(
472
499
  mask_type="tube",
473
500
  shape=template.shape,
474
501
  symmetry_axis=symmetry_axis,
475
- base_center=(center_x, center_y, center_z),
502
+ center=(center_x, center_y, center_z),
476
503
  inner_radius=inner_radius,
477
504
  outer_radius=outer_radius,
478
505
  height=height,
@@ -585,6 +612,7 @@ class MaskWidget(widgets.Container):
585
612
  "Ellipsoid": ellipsod_mask,
586
613
  "Tube": tube_mask,
587
614
  "Box": box_mask,
615
+ "Membrane": membrane_mask,
588
616
  "Wedge": wedge_mask,
589
617
  "Threshold": threshold_mask,
590
618
  "Lowpass": lowpass_mask,
@@ -818,7 +846,7 @@ class AlignmentWidget(widgets.Container):
818
846
  principal_eigenvector = eigenvectors[:, np.argmax(eigenvalues)]
819
847
 
820
848
  rotation_matrix = align_vectors(principal_eigenvector, alignment_axis)
821
- rotated_data, _ = backend.rigid_transform(
849
+ rotated_data, _ = be.rigid_transform(
822
850
  arr=active_layer.data,
823
851
  rotation_matrix=rotation_matrix,
824
852
  use_geometric_center=False,
@@ -954,7 +982,6 @@ class PointCloudWidget(widgets.Container):
954
982
  if not isinstance(layer, napari.layers.Points):
955
983
  continue
956
984
 
957
- layer.face_color = "white"
958
985
  if event == "Label":
959
986
  if len(layer.properties.get("detail", ())) == 0:
960
987
  continue
@@ -971,9 +998,7 @@ class PointCloudWidget(widgets.Container):
971
998
  layer.face_color = "score_scaled"
972
999
  layer.face_colormap = "turbo"
973
1000
  layer.face_color_mode = "colormap"
974
-
975
1001
  layer.refresh_colors()
976
-
977
1002
  return None
978
1003
 
979
1004
  def _set_positive(self, event):
@@ -1141,9 +1166,18 @@ class MatchingWidget(widgets.Container):
1141
1166
  self.viewer = viewer
1142
1167
  self.dataframes = {}
1143
1168
 
1169
+ option_container = widgets.Container(layout="horizontal")
1170
+ self.load_target_checkbox = widgets.CheckBox(text="Load Target", value=False)
1171
+ self.load_rotations_checkbox = widgets.CheckBox(
1172
+ text="Load Rotations", value=False
1173
+ )
1174
+ option_container.append(self.load_target_checkbox)
1175
+ option_container.append(self.load_rotations_checkbox)
1176
+
1144
1177
  self.import_button = widgets.PushButton(name="Import", text="Import Pickle")
1145
1178
  self.import_button.clicked.connect(self._get_load_path)
1146
1179
 
1180
+ self.append(option_container)
1147
1181
  self.append(self.import_button)
1148
1182
 
1149
1183
  def _get_load_path(self, event):
@@ -1151,7 +1185,7 @@ class MatchingWidget(widgets.Container):
1151
1185
  self.native,
1152
1186
  "Open Pickle File...",
1153
1187
  "",
1154
- "Pickle Files (*.pickle);;All Files (*)",
1188
+ "Pickle Files (*.pickle *pickle.gz);;All Files (*)",
1155
1189
  )
1156
1190
  if filename:
1157
1191
  self._load_data(filename)
@@ -1160,14 +1194,35 @@ class MatchingWidget(widgets.Container):
1160
1194
  data = load_pickle(filename)
1161
1195
 
1162
1196
  fname = basename(filename).replace(".pickle", "")
1197
+
1198
+ if self.load_target_checkbox.value:
1199
+ try:
1200
+ target = Density.from_file(data[-1][-1].target)
1201
+ _ = self.viewer.add_image(
1202
+ data=target.data,
1203
+ name=f"{fname}_target",
1204
+ metadata={
1205
+ "origin": target.origin,
1206
+ "sampling_rate": target.sampling_rate,
1207
+ },
1208
+ )
1209
+ except Exception as e:
1210
+ msg = QMessageBox(self.native)
1211
+ msg.setIcon(QMessageBox.Warning)
1212
+ msg.setWindowTitle("Loading Error")
1213
+ msg.setText(str(e))
1214
+ msg.setStandardButtons(QMessageBox.Ok)
1215
+ msg.exec_()
1216
+
1163
1217
  if data[0].ndim == data[2].ndim:
1164
1218
  metadata = {"origin": data[-1][1], "sampling_rate": data[-1][2]}
1165
- _ = self.viewer.add_image(
1166
- data=data[2],
1167
- name=f"{fname}_rotations",
1168
- colormap="orange",
1169
- metadata=metadata,
1170
- )
1219
+ if self.load_rotations_checkbox.value:
1220
+ _ = self.viewer.add_image(
1221
+ data=data[2],
1222
+ name=f"{fname}_rotations",
1223
+ colormap="orange",
1224
+ metadata=metadata,
1225
+ )
1171
1226
  _ = self.viewer.add_image(
1172
1227
  data=data[0],
1173
1228
  name=f"{fname}_scores",
@@ -1175,11 +1230,8 @@ class MatchingWidget(widgets.Container):
1175
1230
  metadata=metadata,
1176
1231
  )
1177
1232
  return None
1178
- detail = np.zeros_like(data[2])
1179
- else:
1180
- detail = data[3]
1181
1233
 
1182
- point_properties = {"score": data[2], "detail": detail}
1234
+ point_properties = {"score": data[2], "detail": data[3]}
1183
1235
  point_properties["score_scaled"] = np.log1p(
1184
1236
  point_properties["score"] - point_properties["score"].min()
1185
1237
  )
@@ -1192,8 +1244,26 @@ class MatchingWidget(widgets.Container):
1192
1244
  )
1193
1245
 
1194
1246
 
1247
+ class CustomNapariViewer(napari.Viewer):
1248
+ """
1249
+ Custom viewer to ensure 3D image layers are by default shown as xy projection.
1250
+ """
1251
+
1252
+ def add_image(self, data, **kwargs):
1253
+ viewer_ndim = len(self.dims.order)
1254
+ layer = super().add_image(data, **kwargs)
1255
+
1256
+ try:
1257
+ # Set to xy view the first time data is opened
1258
+ if viewer_ndim != 3 and data.ndim == 3:
1259
+ self.dims.order = (2, 0, 1)
1260
+ except Exception:
1261
+ pass
1262
+ return layer
1263
+
1264
+
1195
1265
  def main():
1196
- viewer = napari.Viewer()
1266
+ viewer = CustomNapariViewer()
1197
1267
 
1198
1268
  filter_widget = FilterWidget(preprocessor, viewer)
1199
1269
  mask_widget = MaskWidget(viewer)