pytme 0.3b0.post1__cp311-cp311-macosx_15_0_arm64.whl → 0.3.1.dev20250731__cp311-cp311-macosx_15_0_arm64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (60)
  1. pytme-0.3.1.dev20250731.data/scripts/estimate_ram_usage.py +97 -0
  2. {pytme-0.3b0.post1.data → pytme-0.3.1.dev20250731.data}/scripts/match_template.py +30 -41
  3. {pytme-0.3b0.post1.data → pytme-0.3.1.dev20250731.data}/scripts/postprocess.py +35 -21
  4. {pytme-0.3b0.post1.data → pytme-0.3.1.dev20250731.data}/scripts/preprocessor_gui.py +96 -24
  5. pytme-0.3.1.dev20250731.data/scripts/pytme_runner.py +1223 -0
  6. {pytme-0.3b0.post1.dist-info → pytme-0.3.1.dev20250731.dist-info}/METADATA +5 -7
  7. {pytme-0.3b0.post1.dist-info → pytme-0.3.1.dev20250731.dist-info}/RECORD +59 -49
  8. scripts/estimate_ram_usage.py +97 -0
  9. scripts/extract_candidates.py +118 -99
  10. scripts/match_template.py +30 -41
  11. scripts/match_template_devel.py +1339 -0
  12. scripts/postprocess.py +35 -21
  13. scripts/preprocessor_gui.py +96 -24
  14. scripts/pytme_runner.py +644 -190
  15. scripts/refine_matches.py +158 -390
  16. tests/data/.DS_Store +0 -0
  17. tests/data/Blurring/.DS_Store +0 -0
  18. tests/data/Maps/.DS_Store +0 -0
  19. tests/data/Raw/.DS_Store +0 -0
  20. tests/data/Structures/.DS_Store +0 -0
  21. tests/preprocessing/test_utils.py +18 -0
  22. tests/test_analyzer.py +2 -3
  23. tests/test_backends.py +3 -9
  24. tests/test_density.py +0 -1
  25. tests/test_extensions.py +0 -1
  26. tests/test_matching_utils.py +10 -60
  27. tests/test_orientations.py +0 -12
  28. tests/test_rotations.py +1 -1
  29. tme/__version__.py +1 -1
  30. tme/analyzer/_utils.py +4 -4
  31. tme/analyzer/aggregation.py +35 -15
  32. tme/analyzer/peaks.py +11 -10
  33. tme/backends/_jax_utils.py +64 -18
  34. tme/backends/_numpyfftw_utils.py +270 -0
  35. tme/backends/cupy_backend.py +16 -55
  36. tme/backends/jax_backend.py +79 -40
  37. tme/backends/matching_backend.py +17 -51
  38. tme/backends/mlx_backend.py +1 -27
  39. tme/backends/npfftw_backend.py +71 -65
  40. tme/backends/pytorch_backend.py +1 -26
  41. tme/density.py +58 -5
  42. tme/extensions.cpython-311-darwin.so +0 -0
  43. tme/filters/ctf.py +22 -21
  44. tme/filters/wedge.py +10 -7
  45. tme/mask.py +341 -0
  46. tme/matching_data.py +31 -19
  47. tme/matching_exhaustive.py +37 -47
  48. tme/matching_optimization.py +2 -1
  49. tme/matching_scores.py +229 -411
  50. tme/matching_utils.py +73 -422
  51. tme/memory.py +1 -1
  52. tme/orientations.py +24 -13
  53. tme/rotations.py +1 -1
  54. pytme-0.3b0.post1.data/scripts/pytme_runner.py +0 -769
  55. {pytme-0.3b0.post1.data → pytme-0.3.1.dev20250731.data}/scripts/estimate_memory_usage.py +0 -0
  56. {pytme-0.3b0.post1.data → pytme-0.3.1.dev20250731.data}/scripts/preprocess.py +0 -0
  57. {pytme-0.3b0.post1.dist-info → pytme-0.3.1.dev20250731.dist-info}/WHEEL +0 -0
  58. {pytme-0.3b0.post1.dist-info → pytme-0.3.1.dev20250731.dist-info}/entry_points.txt +0 -0
  59. {pytme-0.3b0.post1.dist-info → pytme-0.3.1.dev20250731.dist-info}/licenses/LICENSE +0 -0
  60. {pytme-0.3b0.post1.dist-info → pytme-0.3.1.dev20250731.dist-info}/top_level.txt +0 -0
scripts/postprocess.py CHANGED
@@ -17,7 +17,7 @@ from scipy.special import erfcinv
 
 from tme import Density, Structure, Orientations
 from tme.cli import sanitize_name, print_block, print_entry
-from tme.matching_utils import load_pickle, centered_mask, write_pickle
+from tme.matching_utils import load_pickle, center_slice, write_pickle
 from tme.matching_optimization import create_score_object, optimize_match
 from tme.rotations import euler_to_rotationmatrix, euler_from_rotationmatrix
 from tme.analyzer import (
@@ -87,6 +87,11 @@ def parse_args():
         help="Output prefix. Defaults to basename of first input. Extension is "
         "added with respect to chosen output format.",
     )
+    output_group.add_argument(
+        "--angles-clockwise",
+        action="store_true",
+        help="Report Euler angles in clockwise format expected by RELION.",
+    )
     output_group.add_argument(
         "--output-format",
         choices=[
@@ -112,7 +117,7 @@ def parse_args():
     peak_group.add_argument(
         "--peak-caller",
         choices=list(PEAK_CALLERS.keys()),
-        default="PeakCallerScipy",
+        default="PeakCallerMaximumFilter",
         help="Peak caller for local maxima identification.",
     )
     peak_group.add_argument(
@@ -183,7 +188,7 @@ def parse_args():
     )
     additional_group.add_argument(
         "--n-false-positives",
-        type=int,
+        type=float,
         default=None,
         required=False,
         help="Number of accepted false-positives picks to determine minimum score."
@@ -313,11 +318,7 @@ def normalize_input(foregrounds: Tuple[str], backgrounds: Tuple[str]) -> Tuple:
         data = load_matching_output(foreground)
         scores, _, rotations, rotation_mapping, *_ = data
 
-        # We could normalize to unit sdev, but that might lead to unexpected
-        # results for flat background distributions
-        scores -= scores.mean()
         indices = tuple(slice(0, x) for x in scores.shape)
-
         indices_update = scores > scores_out[indices]
         scores_out[indices][indices_update] = scores[indices_update]
 
@@ -364,9 +365,8 @@ def normalize_input(foregrounds: Tuple[str], backgrounds: Tuple[str]) -> Tuple:
     scores_norm = np.full(out_shape_norm, fill_value=0, dtype=np.float32)
     for background in backgrounds:
         data_norm = load_matching_output(background)
+        scores, _, rotations, rotation_mapping, *_ = data_norm
 
-        scores = data_norm[0]
-        scores -= scores.mean()
         indices = tuple(slice(0, x) for x in scores.shape)
         indices_update = scores > scores_norm[indices]
         scores_norm[indices][indices_update] = scores[indices_update]
@@ -375,8 +375,10 @@ def normalize_input(foregrounds: Tuple[str], backgrounds: Tuple[str]) -> Tuple:
     update = tuple(slice(0, int(x)) for x in np.minimum(out_shape, scores.shape))
     scores_out = np.full(out_shape, fill_value=0, dtype=np.float32)
     scores_out[update] = data[0][update] - scores_norm[update]
-    scores_out[update] = np.divide(scores_out[update], 1 - scores_norm[update])
     scores_out = np.fmax(scores_out, 0, out=scores_out)
+    scores_out[update] += scores_norm[update].mean()
+
+    # scores_out[update] = np.divide(scores_out[update], 1 - scores_norm[update])
     data[0] = scores_out
 
     fg, bg = simple_stats(data[0]), simple_stats(scores_norm)
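Note on the hunk above: the background combination no longer divides by (1 - background). Instead, the background score map is subtracted voxel-wise, negative differences are clamped to zero, and the mean background score is added back. A minimal numpy sketch of that combination, with illustrative array names that are not the actual pytme variables:

    import numpy as np

    fg = np.random.rand(8, 8, 8).astype(np.float32)  # foreground score map
    bg = np.random.rand(8, 8, 8).astype(np.float32)  # background (noise) score map

    out = fg - bg                   # suppress peaks explained by the background
    out = np.fmax(out, 0, out=out)  # negative differences carry no signal
    out += bg.mean()                # restore the overall score level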
@@ -478,15 +480,21 @@ def main():
     if orientations is None:
         translations, rotations, scores, details = [], [], [], []
 
-        # Data processed by normalize_input is guaranteed to have this shape
-        scores, offset, rotation_array, rotation_mapping, meta = data
+        var = None
+        # Data processed by normalize_input is guaranteed to have this shape)
+        scores, _, rotation_array, rotation_mapping, *_ = data
+        if len(data) == 6:
+            scores, _, rotation_array, rotation_mapping, var, *_ = data
 
         cropped_shape = np.subtract(
             scores.shape, np.multiply(args.min_boundary_distance, 2)
         ).astype(int)
 
         if args.min_boundary_distance > 0:
-            scores = centered_mask(scores, new_shape=cropped_shape)
+            _scores = np.zeros_like(scores)
+            subset = center_slice(scores.shape, cropped_shape)
+            _scores[subset] = scores[subset]
+            scores = _scores
 
         if args.n_false_positives is not None:
             # Rickgauer et al. 2017
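In the hunk above, centered_mask is replaced by center_slice, which by its usage here returns a tuple of slices selecting the centered region, so scores outside the boundary margin are zeroed explicitly. A rough sketch of an equivalent helper, assuming that return type; the actual tme.matching_utils.center_slice may be implemented differently:

    import numpy as np

    def center_slice_sketch(shape, new_shape):
        # Tuple of slices selecting a centered block of size new_shape.
        starts = [(s - n) // 2 for s, n in zip(shape, new_shape)]
        return tuple(slice(st, st + n) for st, n in zip(starts, new_shape))

    scores = np.random.rand(10, 10)
    subset = center_slice_sketch(scores.shape, (6, 6))
    masked = np.zeros_like(scores)
    masked[subset] = scores[subset]  # keep the central region, zero the border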
@@ -499,17 +507,20 @@ def main():
             )
             args.n_false_positives = max(args.n_false_positives, 1)
             n_correlations = np.size(scores[cropped_slice]) * len(rotation_mapping)
+            std = np.std(scores[cropped_slice])
+            if var is not None:
+                std = np.asarray(np.sqrt(var)).reshape(())
+
             minimum_score = np.multiply(
                 erfcinv(2 * args.n_false_positives / n_correlations),
-                np.sqrt(2) * np.std(scores[cropped_slice]),
+                np.sqrt(2) * std,
             )
-            print(f"Determined minimum score cutoff: {minimum_score}.")
-            minimum_score = max(minimum_score, 0)
-            args.min_score = minimum_score
+            print(f"Determined cutoff --min-score {minimum_score}.")
+            args.min_score = max(minimum_score, 0)
 
         args.batch_dims = None
-        if hasattr(cli_args, "target_batch"):
-            args.batch_dims = cli_args.target_batch
+        if hasattr(cli_args, "batch_dims"):
+            args.batch_dims = cli_args.batch_dims
 
         peak_caller_kwargs = {
             "shape": scores.shape,
@@ -517,8 +528,8 @@ def main():
             "min_distance": args.min_distance,
             "min_boundary_distance": args.min_boundary_distance,
             "batch_dims": args.batch_dims,
-            "minimum_score": args.min_score,
-            "maximum_score": args.max_score,
+            "min_score": args.min_score,
+            "max_score": args.max_score,
         }
 
         peak_caller = PEAK_CALLERS[args.peak_caller](**peak_caller_kwargs)
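The cutoff introduced two hunks above follows Rickgauer et al. 2017: for approximately Gaussian background scores with standard deviation sigma, the expected number of correlations exceeding a threshold t over N comparisons is N * erfc(t / (sigma * sqrt(2))) / 2, so accepting n false positives gives t = sigma * sqrt(2) * erfcinv(2 * n / N). The new code takes sigma from the stored score variance when the pickle provides one. A worked example with purely illustrative numbers:

    import numpy as np
    from scipy.special import erfcinv

    n_false_positives = 1.0
    n_correlations = 512**3 * 15_000  # voxels in the cropped map times rotations
    sigma = 0.02                      # score standard deviation (illustrative)

    min_score = erfcinv(2 * n_false_positives / n_correlations) * np.sqrt(2) * sigma
    print(f"Determined cutoff --min-score {min_score:.4f}")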
@@ -611,6 +622,9 @@ def main():
             orientations.rotations[index] = angles
             orientations.scores[index] = score * -1
 
+    if args.angles_clockwise:
+        orientations.rotations *= -1
+
 
     if args.output_format in ("orientations", "relion4", "relion5"):
         file_format, extension = "text", "tsv"
scripts/preprocessor_gui.py CHANGED
@@ -19,13 +19,13 @@ from magicgui import widgets
 from numpy.typing import NDArray
 from napari.layers import Image
 from scipy.fft import next_fast_len
-from qtpy.QtWidgets import QFileDialog
 from napari.utils.events import EventedList
+from qtpy.QtWidgets import QFileDialog, QMessageBox
 
-from tme.backends import backend
+from tme.backends import backend as be
 from tme.rotations import align_vectors
-from tme import Preprocessor, Density, Orientations
 from tme.matching_utils import create_mask, load_pickle
+from tme import Preprocessor, Density, Orientations
 from tme.filters import BandPassReconstructed, CTFReconstructed
 
 preprocessor = Preprocessor()
@@ -69,7 +69,7 @@ def ctf_filter(
     flip_phase: bool = False,
 ) -> NDArray:
     fast_shape = [next_fast_len(x) for x in np.multiply(template.shape, 2)]
-    template_pad = backend.topleft_pad(template, fast_shape)
+    template_pad = be.topleft_pad(template, fast_shape)
     template_ft = np.fft.rfftn(template_pad, s=template_pad.shape)
     ctf = CTFReconstructed(
         shape=fast_shape,
@@ -85,7 +85,7 @@ def ctf_filter(
     )
     np.multiply(template_ft, ctf()["data"], out=template_ft)
     template_pad = np.fft.irfftn(template_ft, s=template_pad.shape).real
-    template = backend.topleft_pad(template_pad, template.shape)
+    template = be.topleft_pad(template_pad, template.shape)
     return template
 
 
@@ -392,7 +392,13 @@ class FilterWidget(widgets.Container):
         metadata = selected_layer_metadata.copy()
         if "filter_parameters" not in metadata:
             metadata["filter_parameters"] = []
-        metadata["filter_parameters"].append({filter_name: kwargs.copy()})
+
+        payload = {filter_name: kwargs.copy()}
+        if isinstance(metadata["filter_parameters"], dict):
+            metadata["filter_parameters"].update(payload)
+        else:
+            metadata["filter_parameters"].append(payload)
+
         metadata["used_filter"] = filter_name
         new_layer.metadata = metadata
 
@@ -450,7 +456,30 @@ def box_mask(
         mask_type="box",
         shape=template.shape,
         center=(center_x, center_y, center_z),
-        height=(height_x, height_y, height_z),
+        size=(height_x, height_y, height_z),
+        sigma_decay=sigma_decay,
+    )
+
+
+def membrane_mask(
+    template: NDArray,
+    symmetry_axis: int,
+    center_x: float,
+    center_y: float,
+    center_z: float,
+    radius: float,
+    thickness: float = 1,
+    separation: float = 3,
+    sigma_decay: float = 0,
+    **kwargs,
+) -> NDArray:
+    return create_mask(
+        center=(center_x, center_y, center_z),
+        mask_type="membrane",
+        shape=template.shape,
+        radius=radius,
+        thickness=thickness,
+        separation=separation,
         sigma_decay=sigma_decay,
     )
 
@@ -471,7 +500,7 @@ def tube_mask(
         mask_type="tube",
         shape=template.shape,
         symmetry_axis=symmetry_axis,
-        base_center=(center_x, center_y, center_z),
+        center=(center_x, center_y, center_z),
         inner_radius=inner_radius,
         outer_radius=outer_radius,
         height=height,
@@ -584,6 +613,7 @@ class MaskWidget(widgets.Container):
             "Ellipsoid": ellipsod_mask,
             "Tube": tube_mask,
             "Box": box_mask,
+            "Membrane": membrane_mask,
             "Wedge": wedge_mask,
             "Threshold": threshold_mask,
             "Lowpass": lowpass_mask,
@@ -817,7 +847,7 @@ class AlignmentWidget(widgets.Container):
         principal_eigenvector = eigenvectors[:, np.argmax(eigenvalues)]
 
         rotation_matrix = align_vectors(principal_eigenvector, alignment_axis)
-        rotated_data, _ = backend.rigid_transform(
+        rotated_data, _ = be.rigid_transform(
             arr=active_layer.data,
             rotation_matrix=rotation_matrix,
             use_geometric_center=False,
@@ -953,7 +983,6 @@ class PointCloudWidget(widgets.Container):
             if not isinstance(layer, napari.layers.Points):
                 continue
 
-            layer.face_color = "white"
             if event == "Label":
                 if len(layer.properties.get("detail", ())) == 0:
                     continue
@@ -970,9 +999,7 @@ class PointCloudWidget(widgets.Container):
                 layer.face_color = "score_scaled"
                 layer.face_colormap = "turbo"
                 layer.face_color_mode = "colormap"
-
                 layer.refresh_colors()
-
         return None
 
     def _set_positive(self, event):
@@ -1140,9 +1167,18 @@ class MatchingWidget(widgets.Container):
         self.viewer = viewer
         self.dataframes = {}
 
+        option_container = widgets.Container(layout="horizontal")
+        self.load_target_checkbox = widgets.CheckBox(text="Load Target", value=False)
+        self.load_rotations_checkbox = widgets.CheckBox(
+            text="Load Rotations", value=False
+        )
+        option_container.append(self.load_target_checkbox)
+        option_container.append(self.load_rotations_checkbox)
+
         self.import_button = widgets.PushButton(name="Import", text="Import Pickle")
         self.import_button.clicked.connect(self._get_load_path)
 
+        self.append(option_container)
         self.append(self.import_button)
 
     def _get_load_path(self, event):
@@ -1150,7 +1186,7 @@ class MatchingWidget(widgets.Container):
             self.native,
             "Open Pickle File...",
             "",
-            "Pickle Files (*.pickle);;All Files (*)",
+            "Pickle Files (*.pickle *pickle.gz);;All Files (*)",
         )
         if filename:
             self._load_data(filename)
@@ -1159,14 +1195,35 @@ class MatchingWidget(widgets.Container):
         data = load_pickle(filename)
 
         fname = basename(filename).replace(".pickle", "")
+
+        if self.load_target_checkbox.value:
+            try:
+                target = Density.from_file(data[-1][-1].target)
+                _ = self.viewer.add_image(
+                    data=target.data,
+                    name=f"{fname}_target",
+                    metadata={
+                        "origin": target.origin,
+                        "sampling_rate": target.sampling_rate,
+                    },
+                )
+            except Exception as e:
+                msg = QMessageBox(self.native)
+                msg.setIcon(QMessageBox.Warning)
+                msg.setWindowTitle("Loading Error")
+                msg.setText(str(e))
+                msg.setStandardButtons(QMessageBox.Ok)
+                msg.exec_()
+
         if data[0].ndim == data[2].ndim:
             metadata = {"origin": data[-1][1], "sampling_rate": data[-1][2]}
-            _ = self.viewer.add_image(
-                data=data[2],
-                name=f"{fname}_rotations",
-                colormap="orange",
-                metadata=metadata,
-            )
+            if self.load_rotations_checkbox.value:
+                _ = self.viewer.add_image(
+                    data=data[2],
+                    name=f"{fname}_rotations",
+                    colormap="orange",
+                    metadata=metadata,
+                )
             _ = self.viewer.add_image(
                 data=data[0],
                 name=f"{fname}_scores",
@@ -1174,11 +1231,8 @@ class MatchingWidget(widgets.Container):
                 metadata=metadata,
             )
             return None
-            detail = np.zeros_like(data[2])
-        else:
-            detail = data[3]
 
-        point_properties = {"score": data[2], "detail": detail}
+        point_properties = {"score": data[2], "detail": data[3]}
         point_properties["score_scaled"] = np.log1p(
             point_properties["score"] - point_properties["score"].min()
         )
@@ -1191,8 +1245,26 @@ class MatchingWidget(widgets.Container):
         )
 
 
+class CustomNapariViewer(napari.Viewer):
+    """
+    Custom viewer to ensure 3D image layers are by default shown as xy projection.
+    """
+
+    def add_image(self, data, **kwargs):
+        viewer_ndim = len(self.dims.order)
+        layer = super().add_image(data, **kwargs)
+
+        try:
+            # Set to xy view the first time data is opened
+            if viewer_ndim != 3 and data.ndim == 3:
+                self.dims.order = (2, 0, 1)
+        except Exception:
+            pass
+        return layer
+
+
 def main():
-    viewer = napari.Viewer()
+    viewer = CustomNapariViewer()
 
     filter_widget = FilterWidget(preprocessor, viewer)
     mask_widget = MaskWidget(viewer)
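CustomNapariViewer only reorders viewer.dims.order the first time a 3D volume is added. A sketch of the same effect on a stock viewer, assuming volumes are stored with the projection axis last so that stepping through axis 2 yields an xy view:

    import numpy as np
    import napari

    viewer = napari.Viewer()
    viewer.add_image(np.random.rand(32, 64, 64), name="volume")
    viewer.dims.order = (2, 0, 1)  # slider over axis 2, canvas shows axes 0 and 1
    napari.run()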