pytme 0.2.9.post1__cp311-cp311-macosx_15_0_arm64.whl → 0.3.0__cp311-cp311-macosx_15_0_arm64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (75)
  1. pytme-0.3.0.data/scripts/estimate_memory_usage.py +76 -0
  2. pytme-0.3.0.data/scripts/match_template.py +1106 -0
  3. {pytme-0.2.9.post1.data → pytme-0.3.0.data}/scripts/postprocess.py +320 -190
  4. {pytme-0.2.9.post1.data → pytme-0.3.0.data}/scripts/preprocess.py +21 -31
  5. {pytme-0.2.9.post1.data → pytme-0.3.0.data}/scripts/preprocessor_gui.py +85 -19
  6. pytme-0.3.0.data/scripts/pytme_runner.py +771 -0
  7. {pytme-0.2.9.post1.dist-info → pytme-0.3.0.dist-info}/METADATA +21 -20
  8. pytme-0.3.0.dist-info/RECORD +126 -0
  9. {pytme-0.2.9.post1.dist-info → pytme-0.3.0.dist-info}/entry_points.txt +2 -1
  10. pytme-0.3.0.dist-info/licenses/LICENSE +339 -0
  11. scripts/estimate_memory_usage.py +76 -0
  12. scripts/eval.py +93 -0
  13. scripts/extract_candidates.py +224 -0
  14. scripts/match_template.py +349 -378
  15. pytme-0.2.9.post1.data/scripts/match_template.py → scripts/match_template_filters.py +213 -148
  16. scripts/postprocess.py +320 -190
  17. scripts/preprocess.py +21 -31
  18. scripts/preprocessor_gui.py +85 -19
  19. scripts/pytme_runner.py +771 -0
  20. scripts/refine_matches.py +625 -0
  21. tests/preprocessing/test_frequency_filters.py +28 -14
  22. tests/test_analyzer.py +41 -36
  23. tests/test_backends.py +1 -0
  24. tests/test_matching_cli.py +109 -54
  25. tests/test_matching_data.py +5 -5
  26. tests/test_matching_exhaustive.py +1 -2
  27. tests/test_matching_optimization.py +4 -9
  28. tests/test_matching_utils.py +1 -1
  29. tests/test_orientations.py +0 -1
  30. tme/__version__.py +1 -1
  31. tme/analyzer/__init__.py +2 -0
  32. tme/analyzer/_utils.py +26 -21
  33. tme/analyzer/aggregation.py +395 -222
  34. tme/analyzer/base.py +127 -0
  35. tme/analyzer/peaks.py +189 -204
  36. tme/analyzer/proxy.py +123 -0
  37. tme/backends/__init__.py +4 -3
  38. tme/backends/_cupy_utils.py +25 -24
  39. tme/backends/_jax_utils.py +20 -18
  40. tme/backends/cupy_backend.py +13 -26
  41. tme/backends/jax_backend.py +24 -23
  42. tme/backends/matching_backend.py +4 -3
  43. tme/backends/mlx_backend.py +4 -3
  44. tme/backends/npfftw_backend.py +34 -30
  45. tme/backends/pytorch_backend.py +18 -4
  46. tme/cli.py +126 -0
  47. tme/density.py +9 -7
  48. tme/extensions.cpython-311-darwin.so +0 -0
  49. tme/filters/__init__.py +3 -3
  50. tme/filters/_utils.py +36 -10
  51. tme/filters/bandpass.py +229 -188
  52. tme/filters/compose.py +5 -4
  53. tme/filters/ctf.py +516 -254
  54. tme/filters/reconstruction.py +91 -32
  55. tme/filters/wedge.py +196 -135
  56. tme/filters/whitening.py +37 -42
  57. tme/matching_data.py +28 -39
  58. tme/matching_exhaustive.py +31 -27
  59. tme/matching_optimization.py +5 -4
  60. tme/matching_scores.py +25 -15
  61. tme/matching_utils.py +193 -27
  62. tme/memory.py +4 -3
  63. tme/orientations.py +22 -9
  64. tme/parser.py +114 -33
  65. tme/preprocessor.py +6 -5
  66. tme/rotations.py +10 -7
  67. tme/structure.py +4 -3
  68. pytme-0.2.9.post1.data/scripts/estimate_ram_usage.py +0 -97
  69. pytme-0.2.9.post1.dist-info/RECORD +0 -119
  70. pytme-0.2.9.post1.dist-info/licenses/LICENSE +0 -153
  71. scripts/estimate_ram_usage.py +0 -97
  72. tests/data/Maps/.DS_Store +0 -0
  73. tests/data/Structures/.DS_Store +0 -0
  74. {pytme-0.2.9.post1.dist-info → pytme-0.3.0.dist-info}/WHEEL +0 -0
  75. {pytme-0.2.9.post1.dist-info → pytme-0.3.0.dist-info}/top_level.txt +0 -0
scripts/preprocess.py CHANGED
@@ -1,16 +1,17 @@
1
1
  #!python3
2
- """ Preprocessing routines for template matching.
2
+ """Preprocessing routines for template matching.
3
3
 
4
- Copyright (c) 2023 European Molecular Biology Laboratory
4
+ Copyright (c) 2023 European Molecular Biology Laboratory
5
5
 
6
- Author: Valentin Maurer <valentin.maurer@embl-hamburg.de>
6
+ Author: Valentin Maurer <valentin.maurer@embl-hamburg.de>
7
7
  """
8
8
  import argparse
9
9
  import numpy as np
10
10
 
11
11
  from tme import Density, Structure
12
+ from tme.cli import print_entry
12
13
  from tme.backends import backend as be
13
- from tme.filters import BandPassFilter
14
+ from tme.filters import BandPassReconstructed
14
15
 
15
16
 
16
17
  def parse_args():
@@ -23,7 +24,6 @@ def parse_args():
23
24
  io_group.add_argument(
24
25
  "-m",
25
26
  "--data",
26
- dest="data",
27
27
  type=str,
28
28
  required=True,
29
29
  help="Path to a file in PDB/MMCIF, CCP4/MRC, EM, H5 or a format supported by "
@@ -33,7 +33,6 @@ def parse_args():
33
33
  io_group.add_argument(
34
34
  "-o",
35
35
  "--output",
36
- dest="output",
37
36
  type=str,
38
37
  required=True,
39
38
  help="Path the output should be written to.",
@@ -41,38 +40,34 @@ def parse_args():
41
40
 
42
41
  box_group = parser.add_argument_group("Box")
43
42
  box_group.add_argument(
44
- "--box_size",
45
- dest="box_size",
43
+ "--box-size",
46
44
  type=int,
47
45
  required=False,
48
46
  help="Box size of the output. Defaults to twice the required box size.",
49
47
  )
50
48
  box_group.add_argument(
51
- "--sampling_rate",
52
- dest="sampling_rate",
49
+ "--sampling-rate",
53
50
  type=float,
54
51
  required=True,
55
52
  help="Sampling rate of the output file.",
56
53
  )
57
54
  box_group.add_argument(
58
- "--input_sampling_rate",
59
- dest="input_sampling_rate",
55
+ "--input-sampling-rate",
60
56
  type=float,
61
57
  required=False,
62
- help="Sampling rate of the input file.",
58
+ help="Sampling rate of the input file. Defaults to header for volume "
59
+ "and to --sampling_rate for atomic structures.",
63
60
  )
64
61
 
65
62
  modulation_group = parser.add_argument_group("Modulation")
66
63
  modulation_group.add_argument(
67
- "--invert_contrast",
68
- dest="invert_contrast",
64
+ "--invert-contrast",
69
65
  action="store_true",
70
66
  required=False,
71
67
  help="Inverts the template contrast.",
72
68
  )
73
69
  modulation_group.add_argument(
74
70
  "--lowpass",
75
- dest="lowpass",
76
71
  type=float,
77
72
  required=False,
78
73
  default=None,
@@ -80,14 +75,12 @@ def parse_args():
80
75
  "A value of 0 disables the filter.",
81
76
  )
82
77
  modulation_group.add_argument(
83
- "--no_centering",
84
- dest="no_centering",
78
+ "--no-centering",
85
79
  action="store_true",
86
80
  help="Assumes the template is already centered and omits centering.",
87
81
  )
88
82
  modulation_group.add_argument(
89
83
  "--backend",
90
- dest="backend",
91
84
  type=str,
92
85
  default=None,
93
86
  choices=be.available_backends(),
@@ -96,15 +89,13 @@ def parse_args():
96
89
 
97
90
  alignment_group = parser.add_argument_group("Modulation")
98
91
  alignment_group.add_argument(
99
- "--align_axis",
100
- dest="align_axis",
92
+ "--align-axis",
101
93
  type=int,
102
94
  required=False,
103
95
  help="Align template to given axis, e.g. 2 for z-axis.",
104
96
  )
105
97
  alignment_group.add_argument(
106
- "--align_eigenvector",
107
- dest="align_eigenvector",
98
+ "--align-eigenvector",
108
99
  type=int,
109
100
  required=False,
110
101
  default=0,
@@ -112,8 +103,7 @@ def parse_args():
112
103
  "with numerically largest eigenvalue.",
113
104
  )
114
105
  alignment_group.add_argument(
115
- "--flip_axis",
116
- dest="flip_axis",
106
+ "--flip-axis",
117
107
  action="store_true",
118
108
  required=False,
119
109
  help="Align the template to -axis instead of axis.",
@@ -125,6 +115,7 @@ def parse_args():
125
115
 
126
116
  def main():
127
117
  args = parse_args()
118
+ print_entry()
128
119
 
129
120
  try:
130
121
  data = Structure.from_file(args.data)
@@ -182,19 +173,18 @@ def main():
182
173
  bpf_mask = 1
183
174
  lowpass = 2 * args.sampling_rate if args.lowpass is None else args.lowpass
184
175
  if args.lowpass != 0:
185
- bpf_mask = BandPassFilter(
176
+ bpf_mask = BandPassReconstructed(
186
177
  lowpass=lowpass,
187
178
  highpass=None,
188
179
  use_gaussian=True,
189
180
  return_real_fourier=True,
190
- shape_is_real_fourier=False,
191
181
  sampling_rate=data.sampling_rate,
192
182
  )(shape=data.shape)["data"]
193
- bpf_mask = be.to_numpy_array(bpf_mask)
183
+ bpf_mask = be.to_backend_array(bpf_mask)
194
184
 
195
- data_ft = np.fft.rfftn(data.data, s=data.shape)
196
- data_ft = np.multiply(data_ft, bpf_mask, out=data_ft)
197
- data.data = np.fft.irfftn(data_ft, s=data.shape).real
185
+ data_ft = be.rfftn(be.to_backend_array(data.data), s=data.shape)
186
+ data_ft = be.multiply(data_ft, bpf_mask, out=data_ft)
187
+ data.data = be.to_numpy_array(be.irfftn(data_ft, s=data.shape).real)
198
188
 
199
189
  data = data.resample(args.sampling_rate, method="spline", order=3)
200
190
 
@@ -1,9 +1,10 @@
1
1
  #!python3
2
- """ GUI for identifying adequate template matching filter and masks.
2
+ """
3
+ GUI for identifying suitable masks and analyzing template matchign results.
3
4
 
4
- Copyright (c) 2023 European Molecular Biology Laboratory
5
+ Copyright (c) 2023 European Molecular Biology Laboratory
5
6
 
6
- Author: Valentin Maurer <valentin.maurer@embl-hamburg.de>
7
+ Author: Valentin Maurer <valentin.maurer@embl-hamburg.de>
7
8
  """
8
9
  import inspect
9
10
  import argparse
@@ -18,14 +19,14 @@ from magicgui import widgets
18
19
  from numpy.typing import NDArray
19
20
  from napari.layers import Image
20
21
  from scipy.fft import next_fast_len
21
- from qtpy.QtWidgets import QFileDialog
22
22
  from napari.utils.events import EventedList
23
+ from qtpy.QtWidgets import QFileDialog, QMessageBox
23
24
 
24
25
  from tme.backends import backend
25
26
  from tme.rotations import align_vectors
26
- from tme.filters import BandPassFilter, CTF
27
27
  from tme import Preprocessor, Density, Orientations
28
28
  from tme.matching_utils import create_mask, load_pickle
29
+ from tme.filters import BandPassReconstructed, CTFReconstructed
29
30
 
30
31
  preprocessor = Preprocessor()
31
32
  SLIDER_MIN, SLIDER_MAX = 0, 25
@@ -42,17 +43,16 @@ def bandpass_filter(
42
43
  hard_edges: bool = False,
43
44
  sampling_rate=None,
44
45
  ) -> NDArray:
45
- bpf = BandPassFilter(
46
+ bpf = BandPassReconstructed(
46
47
  lowpass=lowpass_angstrom,
47
48
  highpass=highpass_angstrom,
48
49
  sampling_rate=np.max(sampling_rate),
49
50
  use_gaussian=not hard_edges,
50
- shape_is_real_fourier=True,
51
51
  return_real_fourier=True,
52
52
  )
53
53
  template_ft = np.fft.rfftn(template, s=template.shape)
54
54
 
55
- mask = bpf(shape=template_ft.shape)["data"]
55
+ mask = bpf(shape=template.shape)["data"]
56
56
  np.multiply(template_ft, mask, out=template_ft)
57
57
  return np.fft.irfftn(template_ft, s=template.shape).real
58
58
 
@@ -71,15 +71,14 @@ def ctf_filter(
71
71
  fast_shape = [next_fast_len(x) for x in np.multiply(template.shape, 2)]
72
72
  template_pad = backend.topleft_pad(template, fast_shape)
73
73
  template_ft = np.fft.rfftn(template_pad, s=template_pad.shape)
74
- ctf = CTF(
75
- angles=[0],
74
+ ctf = CTFReconstructed(
76
75
  shape=fast_shape,
77
76
  defocus_x=[defocus_angstrom],
78
77
  acceleration_voltage=acceleration_voltage * 1e3,
79
78
  spherical_aberration=spherical_aberration * 1e7,
80
79
  amplitude_contrast=amplitude_contrast,
81
- phase_shift=[phase_shift],
82
- defocus_angle=[defocus_angle],
80
+ phase_shift=phase_shift,
81
+ defocus_angle=defocus_angle,
83
82
  sampling_rate=np.max(sampling_rate),
84
83
  return_real_fourier=True,
85
84
  flip_phase=flip_phase,
@@ -393,7 +392,13 @@ class FilterWidget(widgets.Container):
393
392
  metadata = selected_layer_metadata.copy()
394
393
  if "filter_parameters" not in metadata:
395
394
  metadata["filter_parameters"] = []
396
- metadata["filter_parameters"].append({filter_name: kwargs.copy()})
395
+
396
+ payload = {filter_name: kwargs.copy()}
397
+ if isinstance(metadata["filter_parameters"], dict):
398
+ metadata["filter_parameters"].update(payload)
399
+ else:
400
+ metadata["filter_parameters"].append(payload)
401
+
397
402
  metadata["used_filter"] = filter_name
398
403
  new_layer.metadata = metadata
399
404
 
@@ -456,6 +461,28 @@ def box_mask(
456
461
  )
457
462
 
458
463
 
464
+ def membrane_mask(
465
+ template: NDArray,
466
+ symmetry_axis: int,
467
+ center_x: float,
468
+ center_y: float,
469
+ center_z: float,
470
+ radius: float,
471
+ thickness: float = 1,
472
+ separation: float = 3,
473
+ sigma_decay: float = 0,
474
+ **kwargs,
475
+ ) -> NDArray:
476
+ return create_mask(
477
+ mask_type="membrane",
478
+ shape=template.shape,
479
+ radius=radius,
480
+ thickness=thickness,
481
+ separation=separation,
482
+ sigma_decay=sigma_decay,
483
+ )
484
+
485
+
459
486
  def tube_mask(
460
487
  template: NDArray,
461
488
  symmetry_axis: int,
@@ -585,6 +612,7 @@ class MaskWidget(widgets.Container):
585
612
  "Ellipsoid": ellipsod_mask,
586
613
  "Tube": tube_mask,
587
614
  "Box": box_mask,
615
+ "Membrane": membrane_mask,
588
616
  "Wedge": wedge_mask,
589
617
  "Threshold": threshold_mask,
590
618
  "Lowpass": lowpass_mask,
@@ -1141,6 +1169,8 @@ class MatchingWidget(widgets.Container):
1141
1169
  self.viewer = viewer
1142
1170
  self.dataframes = {}
1143
1171
 
1172
+ self.load_target_checkbox = widgets.CheckBox(text="Load Target", value=False)
1173
+ self.append(self.load_target_checkbox)
1144
1174
  self.import_button = widgets.PushButton(name="Import", text="Import Pickle")
1145
1175
  self.import_button.clicked.connect(self._get_load_path)
1146
1176
 
@@ -1160,6 +1190,26 @@ class MatchingWidget(widgets.Container):
1160
1190
  data = load_pickle(filename)
1161
1191
 
1162
1192
  fname = basename(filename).replace(".pickle", "")
1193
+
1194
+ if self.load_target_checkbox.value:
1195
+ try:
1196
+ target = Density.from_file(data[-1][-1].target)
1197
+ _ = self.viewer.add_image(
1198
+ data=target.data,
1199
+ name=f"{fname}_target",
1200
+ metadata={
1201
+ "origin": target.origin,
1202
+ "sampling_rate": target.sampling_rate,
1203
+ },
1204
+ )
1205
+ except Exception as e:
1206
+ msg = QMessageBox(self.native)
1207
+ msg.setIcon(QMessageBox.Warning)
1208
+ msg.setWindowTitle("Loading Error")
1209
+ msg.setText(str(e))
1210
+ msg.setStandardButtons(QMessageBox.Ok)
1211
+ msg.exec_()
1212
+
1163
1213
  if data[0].ndim == data[2].ndim:
1164
1214
  metadata = {"origin": data[-1][1], "sampling_rate": data[-1][2]}
1165
1215
  _ = self.viewer.add_image(
@@ -1175,11 +1225,8 @@ class MatchingWidget(widgets.Container):
1175
1225
  metadata=metadata,
1176
1226
  )
1177
1227
  return None
1178
- detail = np.zeros_like(data[2])
1179
- else:
1180
- detail = data[3]
1181
1228
 
1182
- point_properties = {"score": data[2], "detail": detail}
1229
+ point_properties = {"score": data[2], "detail": data[3]}
1183
1230
  point_properties["score_scaled"] = np.log1p(
1184
1231
  point_properties["score"] - point_properties["score"].min()
1185
1232
  )
@@ -1192,8 +1239,26 @@ class MatchingWidget(widgets.Container):
1192
1239
  )
1193
1240
 
1194
1241
 
1242
+ class CustomNapariViewer(napari.Viewer):
1243
+ """
1244
+ Custom viewer to ensure 3D image layers are by default shown as xy projection.
1245
+ """
1246
+
1247
+ def add_image(self, data, **kwargs):
1248
+ viewer_ndim = len(self.dims.order)
1249
+ layer = super().add_image(data, **kwargs)
1250
+
1251
+ try:
1252
+ # Set to xy view the first time data is opened
1253
+ if viewer_ndim != 3 and data.ndim == 3:
1254
+ self.dims.order = (2, 0, 1)
1255
+ except Exception:
1256
+ pass
1257
+ return layer
1258
+
1259
+
1195
1260
  def main():
1196
- viewer = napari.Viewer()
1261
+ viewer = CustomNapariViewer()
1197
1262
 
1198
1263
  filter_widget = FilterWidget(preprocessor, viewer)
1199
1264
  mask_widget = MaskWidget(viewer)
@@ -1216,7 +1281,8 @@ def main():
1216
1281
 
1217
1282
  def parse_args():
1218
1283
  parser = argparse.ArgumentParser(
1219
- description="GUI for preparing and analyzing template matching runs."
1284
+ description="GUI for preparing and analyzing template matching runs.",
1285
+ formatter_class=argparse.ArgumentDefaultsHelpFormatter,
1220
1286
  )
1221
1287
  args = parser.parse_args()
1222
1288
  return args