pytme 0.2.9__cp311-cp311-macosx_15_0_arm64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (119)
  1. pytme-0.2.9.data/scripts/estimate_ram_usage.py +97 -0
  2. pytme-0.2.9.data/scripts/match_template.py +1135 -0
  3. pytme-0.2.9.data/scripts/postprocess.py +622 -0
  4. pytme-0.2.9.data/scripts/preprocess.py +209 -0
  5. pytme-0.2.9.data/scripts/preprocessor_gui.py +1227 -0
  6. pytme-0.2.9.dist-info/METADATA +95 -0
  7. pytme-0.2.9.dist-info/RECORD +119 -0
  8. pytme-0.2.9.dist-info/WHEEL +5 -0
  9. pytme-0.2.9.dist-info/entry_points.txt +6 -0
  10. pytme-0.2.9.dist-info/licenses/LICENSE +153 -0
  11. pytme-0.2.9.dist-info/top_level.txt +3 -0
  12. scripts/__init__.py +0 -0
  13. scripts/estimate_ram_usage.py +97 -0
  14. scripts/match_template.py +1135 -0
  15. scripts/postprocess.py +622 -0
  16. scripts/preprocess.py +209 -0
  17. scripts/preprocessor_gui.py +1227 -0
  18. tests/__init__.py +0 -0
  19. tests/data/Blurring/blob_width18.npy +0 -0
  20. tests/data/Blurring/edgegaussian_sigma3.npy +0 -0
  21. tests/data/Blurring/gaussian_sigma2.npy +0 -0
  22. tests/data/Blurring/hamming_width6.npy +0 -0
  23. tests/data/Blurring/kaiserb_width18.npy +0 -0
  24. tests/data/Blurring/localgaussian_sigma0510.npy +0 -0
  25. tests/data/Blurring/mean_size5.npy +0 -0
  26. tests/data/Blurring/ntree_sigma0510.npy +0 -0
  27. tests/data/Blurring/rank_rank3.npy +0 -0
  28. tests/data/Maps/.DS_Store +0 -0
  29. tests/data/Maps/emd_8621.mrc.gz +0 -0
  30. tests/data/README.md +2 -0
  31. tests/data/Raw/em_map.map +0 -0
  32. tests/data/Structures/.DS_Store +0 -0
  33. tests/data/Structures/1pdj.cif +3339 -0
  34. tests/data/Structures/1pdj.pdb +1429 -0
  35. tests/data/Structures/5khe.cif +3685 -0
  36. tests/data/Structures/5khe.ent +2210 -0
  37. tests/data/Structures/5khe.pdb +2210 -0
  38. tests/data/Structures/5uz4.cif +70548 -0
  39. tests/preprocessing/__init__.py +0 -0
  40. tests/preprocessing/test_compose.py +76 -0
  41. tests/preprocessing/test_frequency_filters.py +178 -0
  42. tests/preprocessing/test_preprocessor.py +136 -0
  43. tests/preprocessing/test_utils.py +79 -0
  44. tests/test_analyzer.py +216 -0
  45. tests/test_backends.py +446 -0
  46. tests/test_density.py +503 -0
  47. tests/test_extensions.py +130 -0
  48. tests/test_matching_cli.py +283 -0
  49. tests/test_matching_data.py +162 -0
  50. tests/test_matching_exhaustive.py +124 -0
  51. tests/test_matching_memory.py +30 -0
  52. tests/test_matching_optimization.py +226 -0
  53. tests/test_matching_utils.py +189 -0
  54. tests/test_orientations.py +175 -0
  55. tests/test_parser.py +33 -0
  56. tests/test_rotations.py +153 -0
  57. tests/test_structure.py +247 -0
  58. tme/__init__.py +6 -0
  59. tme/__version__.py +1 -0
  60. tme/analyzer/__init__.py +2 -0
  61. tme/analyzer/_utils.py +186 -0
  62. tme/analyzer/aggregation.py +577 -0
  63. tme/analyzer/peaks.py +953 -0
  64. tme/backends/__init__.py +171 -0
  65. tme/backends/_cupy_utils.py +734 -0
  66. tme/backends/_jax_utils.py +188 -0
  67. tme/backends/cupy_backend.py +294 -0
  68. tme/backends/jax_backend.py +314 -0
  69. tme/backends/matching_backend.py +1270 -0
  70. tme/backends/mlx_backend.py +241 -0
  71. tme/backends/npfftw_backend.py +583 -0
  72. tme/backends/pytorch_backend.py +430 -0
  73. tme/data/__init__.py +0 -0
  74. tme/data/c48n309.npy +0 -0
  75. tme/data/c48n527.npy +0 -0
  76. tme/data/c48n9.npy +0 -0
  77. tme/data/c48u1.npy +0 -0
  78. tme/data/c48u1153.npy +0 -0
  79. tme/data/c48u1201.npy +0 -0
  80. tme/data/c48u1641.npy +0 -0
  81. tme/data/c48u181.npy +0 -0
  82. tme/data/c48u2219.npy +0 -0
  83. tme/data/c48u27.npy +0 -0
  84. tme/data/c48u2947.npy +0 -0
  85. tme/data/c48u3733.npy +0 -0
  86. tme/data/c48u4749.npy +0 -0
  87. tme/data/c48u5879.npy +0 -0
  88. tme/data/c48u7111.npy +0 -0
  89. tme/data/c48u815.npy +0 -0
  90. tme/data/c48u83.npy +0 -0
  91. tme/data/c48u8649.npy +0 -0
  92. tme/data/c600v.npy +0 -0
  93. tme/data/c600vc.npy +0 -0
  94. tme/data/metadata.yaml +80 -0
  95. tme/data/quat_to_numpy.py +42 -0
  96. tme/data/scattering_factors.pickle +0 -0
  97. tme/density.py +2263 -0
  98. tme/extensions.cpython-311-darwin.so +0 -0
  99. tme/external/bindings.cpp +332 -0
  100. tme/filters/__init__.py +6 -0
  101. tme/filters/_utils.py +311 -0
  102. tme/filters/bandpass.py +230 -0
  103. tme/filters/compose.py +81 -0
  104. tme/filters/ctf.py +393 -0
  105. tme/filters/reconstruction.py +160 -0
  106. tme/filters/wedge.py +542 -0
  107. tme/filters/whitening.py +191 -0
  108. tme/matching_data.py +863 -0
  109. tme/matching_exhaustive.py +497 -0
  110. tme/matching_optimization.py +1311 -0
  111. tme/matching_scores.py +1183 -0
  112. tme/matching_utils.py +1188 -0
  113. tme/memory.py +337 -0
  114. tme/orientations.py +598 -0
  115. tme/parser.py +685 -0
  116. tme/preprocessor.py +1329 -0
  117. tme/rotations.py +350 -0
  118. tme/structure.py +1864 -0
  119. tme/types.py +13 -0
@@ -0,0 +1,1227 @@
1
+ #!python
2
+ """ GUI for identifying adequate template matching filter and masks.
3
+
4
+ Copyright (c) 2023 European Molecular Biology Laboratory
5
+
6
+ Author: Valentin Maurer <valentin.maurer@embl-hamburg.de>
7
+ """
8
+ import inspect
9
+ import argparse
10
+ from os.path import basename
11
+ from typing_extensions import Annotated
12
+ from typing import Tuple, Callable, List
13
+
14
+ import napari
15
+ import numpy as np
16
+ import pandas as pd
17
+ from magicgui import widgets
18
+ from numpy.typing import NDArray
19
+ from napari.layers import Image
20
+ from scipy.fft import next_fast_len
21
+ from qtpy.QtWidgets import QFileDialog
22
+ from napari.utils.events import EventedList
23
+
24
+ from tme.backends import backend
25
+ from tme.rotations import align_vectors
26
+ from tme.filters import BandPassFilter, CTF
27
+ from tme import Preprocessor, Density, Orientations
28
+ from tme.matching_utils import create_mask, load_pickle
29
+
30
# Shared Preprocessor instance used by the filter and mask helpers below.
preprocessor = Preprocessor()

# Bounds used for the spin boxes and range sliders built by widgets_from_function.
SLIDER_MIN = 0
SLIDER_MAX = 25
32
+
33
+
34
def gaussian_filter(template: NDArray, sigma: float, **kwargs: dict) -> NDArray:
    """
    Blur ``template`` with the shared Preprocessor's Gaussian filter.

    Any additional keyword arguments are forwarded unchanged.
    """
    call_arguments = {"template": template, "sigma": sigma, **kwargs}
    return preprocessor.gaussian_filter(**call_arguments)
36
+
37
+
38
def bandpass_filter(
    template: NDArray,
    lowpass_angstrom: float = 30,
    highpass_angstrom: float = 140,
    hard_edges: bool = False,
    sampling_rate=None,
) -> NDArray:
    """
    Band-pass filter ``template`` in (real) Fourier space.

    Parameters
    ----------
    template : NDArray
        Real-valued input array.
    lowpass_angstrom, highpass_angstrom : float, optional
        Passed to :class:`BandPassFilter` as ``lowpass`` and ``highpass``.
    hard_edges : bool, optional
        If True, construct the filter with ``use_gaussian=False``.
    sampling_rate : optional
        Reduced to a scalar with ``np.max`` before being passed on.

    Returns
    -------
    NDArray
        Filtered array with the same shape as ``template``.
    """
    band = BandPassFilter(
        lowpass=lowpass_angstrom,
        highpass=highpass_angstrom,
        sampling_rate=np.max(sampling_rate),
        use_gaussian=not hard_edges,
        shape_is_real_fourier=True,
        return_real_fourier=True,
    )
    real_shape = template.shape
    spectrum = np.fft.rfftn(template, s=real_shape)

    spectrum *= band(shape=spectrum.shape)["data"]
    return np.fft.irfftn(spectrum, s=real_shape).real
58
+
59
+
60
def ctf_filter(
    template: NDArray,
    defocus_angstrom: float = 30000,
    acceleration_voltage: float = 300,
    spherical_aberration: float = 2.7,
    amplitude_contrast: float = 0.07,
    phase_shift: float = 0,
    defocus_angle: float = 0,
    sampling_rate=None,
    flip_phase: bool = False,
) -> NDArray:
    """
    Multiply ``template`` with a single contrast transfer function.

    The template is padded with ``backend.topleft_pad`` to a fast FFT size of
    twice its shape, filtered in real Fourier space and cropped back.

    Parameters
    ----------
    template : NDArray
        Real-valued input array.
    defocus_angstrom : float, optional
        Passed as the single ``defocus_x`` value.
    acceleration_voltage : float, optional
        Multiplied by 1e3 before being passed to CTF (presumably kV -> V).
    spherical_aberration : float, optional
        Multiplied by 1e7 before being passed to CTF (presumably mm -> A).
    amplitude_contrast, phase_shift, defocus_angle : float, optional
        Passed to CTF; phase shift and defocus angle as one-element lists.
    sampling_rate : optional
        Reduced to a scalar with ``np.max``.
    flip_phase : bool, optional
        Forwarded to CTF.

    Returns
    -------
    NDArray
        Filtered array with the shape of ``template``.
    """
    original_shape = template.shape
    padded_shape = [next_fast_len(x) for x in np.multiply(original_shape, 2)]

    padded = backend.topleft_pad(template, padded_shape)
    spectrum = np.fft.rfftn(padded, s=padded.shape)

    transfer_function = CTF(
        angles=[0],
        shape=padded_shape,
        defocus_x=[defocus_angstrom],
        acceleration_voltage=acceleration_voltage * 1e3,
        spherical_aberration=spherical_aberration * 1e7,
        amplitude_contrast=amplitude_contrast,
        phase_shift=[phase_shift],
        defocus_angle=[defocus_angle],
        sampling_rate=np.max(sampling_rate),
        return_real_fourier=True,
        flip_phase=flip_phase,
    )
    np.multiply(spectrum, transfer_function()["data"], out=spectrum)

    filtered = np.fft.irfftn(spectrum, s=padded.shape).real
    return backend.topleft_pad(filtered, original_shape)
91
+
92
+
93
def difference_of_gaussian_filter(
    template: NDArray, sigmas: Tuple[float, float], **kwargs: dict
) -> NDArray:
    """
    Apply the Preprocessor difference-of-Gaussian filter.

    ``sigmas`` is unpacked into ``(low_sigma, high_sigma)``; extra keyword
    arguments are forwarded unchanged.
    """
    sigma_low, sigma_high = sigmas
    return preprocessor.difference_of_gaussian_filter(
        template=template,
        low_sigma=sigma_low,
        high_sigma=sigma_high,
        **kwargs,
    )
100
+
101
+
102
def edge_gaussian_filter(
    template: NDArray,
    sigma: float,
    edge_algorithm: Annotated[
        str,
        {"choices": ["sobel", "prewitt", "laplace", "gaussian", "gaussian_laplace"]},
    ],
    reverse: bool = False,
    **kwargs: dict,
) -> NDArray:
    """
    Apply the Preprocessor edge-weighted Gaussian filter.

    ``edge_algorithm`` is annotated with its allowed choices so the GUI can
    render a dropdown. Extra keyword arguments are accepted and ignored.
    """
    options = dict(sigma=sigma, reverse=reverse, edge_algorithm=edge_algorithm)
    return preprocessor.edge_gaussian_filter(template=template, **options)
118
+
119
+
120
def local_gaussian_filter(
    template: NDArray,
    lbd: float,
    sigma_range: Tuple[float, float],
    gaussian_sigma: float,
    reverse: bool = False,
    **kwargs: dict,
) -> NDArray:
    """
    Apply the Preprocessor local Gaussian filter.

    ``lbd``, ``sigma_range`` and ``gaussian_sigma`` are forwarded unchanged.

    NOTE(review): ``reverse`` and any extra keyword arguments are accepted but
    not forwarded, so the checkbox generated for ``reverse`` has no effect --
    confirm whether the Preprocessor method should receive it.
    """
    forwarded = dict(lbd=lbd, sigma_range=sigma_range, gaussian_sigma=gaussian_sigma)
    return preprocessor.local_gaussian_filter(template=template, **forwarded)
134
+
135
+
136
def mean(
    template: NDArray,
    width: int,
    **kwargs: dict,
) -> NDArray:
    """Apply the Preprocessor mean filter; extra keyword arguments are ignored."""
    return preprocessor.mean_filter(width=width, template=template)
142
+
143
+
144
def wedge(
    template: NDArray,
    tilt_start: float = 40.0,
    tilt_stop: float = 40.0,
    tilt_step: float = 0,
    opening_axis: int = 2,
    tilt_axis: int = 0,
    omit_negative_frequencies: bool = False,
    infinite_plane: bool = False,
    weight_angle: bool = False,
    **kwargs,
) -> NDArray:
    """
    Apply a missing-wedge mask to ``template`` in Fourier space.

    A continuous wedge is used when ``tilt_step <= 0``. Otherwise a step wedge
    is sampled at ``np.arange(-tilt_start, tilt_stop + tilt_step, tilt_step)``,
    optionally weighted by the cosine of each tilt angle (``weight_angle``).
    ``infinite_plane`` only affects the continuous wedge.

    Returns
    -------
    NDArray
        Real part of the inverse FFT of the masked full spectrum.
    """
    spectrum = np.fft.fftn(template)

    if tilt_step <= 0:
        mask = preprocessor.continuous_wedge_mask(
            start_tilt=tilt_start,
            stop_tilt=tilt_stop,
            tilt_axis=tilt_axis,
            opening_axis=opening_axis,
            shape=template.shape,
            omit_negative_frequencies=omit_negative_frequencies,
            infinite_plane=infinite_plane,
        )
    else:
        angles = np.arange(-tilt_start, tilt_stop + tilt_step, tilt_step)
        mask = preprocessor.step_wedge_mask(
            tilt_angles=angles,
            tilt_axis=tilt_axis,
            opening_axis=opening_axis,
            shape=template.shape,
            weights=np.cos(np.radians(angles)) if weight_angle else None,
            omit_negative_frequencies=omit_negative_frequencies,
        )

    np.multiply(spectrum, mask, out=spectrum)
    return np.real(np.fft.ifftn(spectrum))
186
+
187
+
188
def compute_power_spectrum(template: NDArray) -> NDArray:
    """
    Return the centered log-amplitude spectrum of ``template``.

    Parameters
    ----------
    template : NDArray
        Input array of any dimensionality.

    Returns
    -------
    NDArray
        ``fftshift(log(|fftn(template)|))``. Zero amplitudes are clamped to
        the smallest positive float before the log so the result is finite.
    """
    amplitude = np.abs(np.fft.fftn(template))
    # log(0) would give -inf (plus a RuntimeWarning), which breaks automatic
    # contrast limits when the spectrum is shown as an image layer.
    np.maximum(amplitude, np.finfo(amplitude.dtype).tiny, out=amplitude)
    return np.fft.fftshift(np.log(amplitude))
190
+
191
+
192
def invert_contrast(template: NDArray) -> NDArray:
    """Return ``template`` with every value negated (multiplied by -1)."""
    return -1 * template
194
+
195
+
196
def widgets_from_function(
    function: Callable, exclude_params: Tuple[str, ...] = ("self",)
):
    """
    Create magicgui widgets from the annotated parameters of ``function``.

    Parameters
    ----------
    function : Callable
        Callable whose signature is inspected.
    exclude_params : sequence of str, optional
        Parameter names for which no widget is created, by default ``("self",)``.

    Returns
    -------
    list
        Widgets in signature order. Supported annotations are ``float``
        (FloatSpinBox), ``Tuple[float, float]`` (FloatRangeSlider), ``int``
        (SpinBox), ``bool`` (CheckBox) and ``Annotated[..., {"choices": [...]}]``
        (ComboBox). Parameters with any other or no annotation are skipped.
        Parameter defaults, if present, become the initial widget values.
    """
    empty = inspect.Parameter.empty
    ret = []
    for name, param in inspect.signature(function).parameters.items():
        if name in exclude_params:
            continue

        annotation = param.annotation
        # Identity check: `!=` would misbehave for array-valued defaults.
        has_default = param.default is not empty

        if annotation is float:
            widget = widgets.FloatSpinBox(
                name=name,
                value=param.default if has_default else 0,
                min=SLIDER_MIN,
                step=0.5,
            )
        elif annotation == Tuple[float, float]:
            widget = widgets.FloatRangeSlider(
                name=name,
                value=param.default if has_default else (0.0, SLIDER_MAX / 2),
                min=SLIDER_MIN,
                max=SLIDER_MAX,
            )
        elif annotation is int:
            widget = widgets.SpinBox(
                name=name,
                value=param.default if has_default else 0,
            )
        elif annotation is bool:
            widget = widgets.CheckBox(
                name=name,
                value=param.default if has_default else False,
            )
        elif hasattr(annotation, "__metadata__"):
            # Only Annotated[..., {"choices": [...]}] is supported. Any other
            # metadata is skipped rather than appending an unassigned or stale
            # widget from a previous iteration.
            try:
                choices = annotation.__metadata__[0]["choices"]
            except (KeyError, IndexError, TypeError):
                continue
            widget = widgets.ComboBox(
                name=name,
                choices=choices,
                value=param.default if has_default else choices[0],
            )
        else:
            continue
        ret.append(widget)
    return ret
249
+
250
+
251
# Filters implemented in this module rather than as Preprocessor methods.
# FilterWidget lists these keys (sanitized) alongside the Preprocessor methods.
WRAPPED_FUNCTIONS = dict(
    gaussian_filter=gaussian_filter,
    bandpass_filter=bandpass_filter,
    edge_gaussian_filter=edge_gaussian_filter,
    local_gaussian_filter=local_gaussian_filter,
    difference_of_gaussian_filter=difference_of_gaussian_filter,
    mean_filter=mean,
    wedge_filter=wedge,
    power_spectrum=compute_power_spectrum,
    ctf=ctf_filter,
    invert_contrast=invert_contrast,
)
263
+
264
# Public Preprocessor methods that FilterWidget does not offer as filters.
# Each name is listed once (the list previously contained "wedge_mask" twice).
EXCLUDED_FUNCTIONS = [
    "apply_method",
    "method_to_id",
    "wedge_mask",
    "fourier_crop",
    "fourier_uncrop",
    "interpolate_box",
    "molmap",
    "local_gaussian_alignment_filter",
    "continuous_wedge_mask",
    "bandpass_mask",
]
277
+
278
+
279
class FilterWidget(widgets.Container):
    """
    Container for applying filters to napari image layers.

    The filter dropdown lists public Preprocessor methods (minus
    EXCLUDED_FUNCTIONS) plus the keys of WRAPPED_FUNCTIONS. Parameter widgets
    for the selected filter are generated by widgets_from_function, and the
    result is written to a layer named "<layer> (<filter label>)".
    """

    def __init__(self, preprocessor, viewer):
        super().__init__(layout="vertical")

        self.preprocessor = preprocessor
        self.viewer = viewer
        # Sanitized dropdown label -> original function name.
        self.name_mapping = {}
        # Parameter widgets of the currently selected filter.
        self.action_widgets = []

        self.layer_dropdown = widgets.ComboBox(
            name="Target Layer", choices=self._get_layer_names()
        )
        self.append(self.layer_dropdown)
        self.viewer.layers.events.inserted.connect(self._update_layer_dropdown)
        self.viewer.layers.events.removed.connect(self._update_layer_dropdown)

        self.method_dropdown = widgets.ComboBox(
            name="Choose Filter", choices=self._get_method_names()
        )
        self.method_dropdown.changed.connect(self._on_method_changed)
        self.append(self.method_dropdown)

        # Enable right away if the viewer already holds layers; otherwise the
        # button would stay disabled until the next layer insert/remove event.
        self.apply_btn = widgets.PushButton(
            text="Apply Filter", enabled=bool(self.viewer.layers)
        )
        self.apply_btn.changed.connect(self._action)
        self.append(self.apply_btn)

        # Create GUI for initially selected filtering method
        self._on_method_changed(None)

    def _get_method_names(self):
        """Return the sorted, sanitized labels of all selectable filters."""
        method_names = [
            name
            for name, member in inspect.getmembers(self.preprocessor, inspect.ismethod)
            if not name.startswith("_") and name not in EXCLUDED_FUNCTIONS
        ]
        method_names.extend(WRAPPED_FUNCTIONS.keys())
        method_names = list(set(method_names))

        sanitized_names = [self._sanitize_name(name) for name in method_names]
        self.name_mapping.update(zip(sanitized_names, method_names))
        sanitized_names.sort()

        return sanitized_names

    def _sanitize_name(self, name: str) -> str:
        """
        Turn a function name into a dropdown label.

        Removes the substrings "blur" and "filter", converts underscores to
        single spaces, trims surrounding whitespace and title-cases the result
        (e.g. "gaussian_filter" -> "Gaussian", without a trailing space).
        """
        for remove in ("blur", "filter"):
            name = name.replace(remove, "")
        return " ".join(name.replace("_", " ").split()).title()

    def _desanitize_name(self, name: str) -> str:
        """
        Map a dropdown label back to a name recorded in layer metadata.

        Returns the first Preprocessor method whose name starts with the
        lower-cased label, or the lower-cased label itself if none matches.
        """
        name = name.lower().strip()
        for function_name, _ in inspect.getmembers(self.preprocessor, inspect.ismethod):
            if function_name.startswith(name):
                return function_name
        return name

    def _get_function(self, name: str):
        """Resolve ``name`` to a wrapper function or a Preprocessor method."""
        function = WRAPPED_FUNCTIONS.get(name, None)
        if function is None:
            function = getattr(self.preprocessor, name, None)
        return function

    def _on_method_changed(self, event=None):
        """Replace the parameter widgets with those of the selected filter."""
        for widget in self.action_widgets:
            self.remove(widget)
        self.action_widgets.clear()

        function_name = self.name_mapping.get(self.method_dropdown.value)
        function = self._get_function(function_name)

        parameter_widgets = widgets_from_function(
            function, exclude_params=["self", "template"]
        )
        for widget in parameter_widgets:
            self.action_widgets.append(widget)
            # -1 keeps the widgets in order, just before the apply button.
            self.insert(-1, widget)

    def _update_layer_dropdown(self, event: EventedList):
        """Update the dropdown menu when layers change."""
        self.layer_dropdown.choices = self._get_layer_names()
        self.apply_btn.enabled = bool(self.viewer.layers)

    def _get_layer_names(self):
        """Return list of layer names in the viewer."""
        return sorted(layer.name for layer in self.viewer.layers)

    def _action(self, event):
        """Apply the selected filter to the target layer and store the result."""
        selected_layer = self.viewer.layers[self.layer_dropdown.value]
        selected_layer_metadata = selected_layer.metadata.copy()
        kwargs = {widget.name: widget.value for widget in self.action_widgets}

        function_name = self.name_mapping.get(self.method_dropdown.value)
        function = self._get_function(function_name)

        if "sampling_rate" in inspect.getfullargspec(function).args:
            kwargs["sampling_rate"] = selected_layer_metadata["sampling_rate"]

        processed_data = function(selected_layer.data, **kwargs)

        new_layer_name = f"{selected_layer.name} ({self.method_dropdown.value})"
        if new_layer_name in self.viewer.layers:
            selected_layer = self.viewer.layers[new_layer_name]

        filter_name = self._desanitize_name(self.method_dropdown.value)
        used_filter = selected_layer.metadata.get("used_filter", False)
        if used_filter == filter_name:
            selected_layer.data = processed_data
            return None

        new_layer = self.viewer.add_image(
            data=processed_data,
            name=new_layer_name,
        )
        metadata = selected_layer_metadata.copy()

        previous = metadata.get("filter_parameters", [])
        if isinstance(previous, dict):
            # Mask layers store a single {name: parameters} mapping, not a list.
            previous = [previous]
        # Build a new list: the metadata copies are shallow, so appending in
        # place would also modify the source layer's recorded parameters.
        metadata["filter_parameters"] = [*previous, {filter_name: kwargs.copy()}]
        metadata["used_filter"] = filter_name
        new_layer.metadata = metadata
399
+
400
+
401
def sphere_mask(
    template: NDArray,
    center_x: float,
    center_y: float,
    center_z: float,
    radius: float,
    sigma_decay: float = 0,
    **kwargs,
) -> NDArray:
    """
    Create a spherical mask with the shape of ``template``.

    Implemented as an ``"ellipse"`` mask from ``create_mask`` with a single
    radius. Extra keyword arguments are ignored.
    """
    center = (center_x, center_y, center_z)
    return create_mask(
        mask_type="ellipse",
        shape=template.shape,
        center=center,
        radius=radius,
        sigma_decay=sigma_decay,
    )
417
+
418
+
419
def ellipsod_mask(
    template: NDArray,
    center_x: float,
    center_y: float,
    center_z: float,
    radius_x: float,
    radius_y: float,
    radius_z: float,
    sigma_decay: float = 0,
    **kwargs,
) -> NDArray:
    """
    Create an ellipsoidal mask with the shape of ``template``.

    One radius per axis is passed to ``create_mask``. Extra keyword arguments
    are ignored.
    """
    center = (center_x, center_y, center_z)
    radii = (radius_x, radius_y, radius_z)
    return create_mask(
        mask_type="ellipse",
        shape=template.shape,
        center=center,
        radius=radii,
        sigma_decay=sigma_decay,
    )
437
+
438
+
439
def box_mask(
    template: NDArray,
    center_x: float,
    center_y: float,
    center_z: float,
    height_x: int,
    height_y: int,
    height_z: int,
    sigma_decay: float = 0,
    **kwargs,
) -> NDArray:
    """
    Create a box mask with the shape of ``template``.

    One height per axis is passed to ``create_mask``. Extra keyword arguments
    are ignored.
    """
    center = (center_x, center_y, center_z)
    heights = (height_x, height_y, height_z)
    return create_mask(
        mask_type="box",
        shape=template.shape,
        center=center,
        height=heights,
        sigma_decay=sigma_decay,
    )
457
+
458
+
459
def tube_mask(
    template: NDArray,
    symmetry_axis: int,
    center_x: float,
    center_y: float,
    center_z: float,
    inner_radius: float,
    outer_radius: float,
    height: int,
    sigma_decay: float = 0,
    **kwargs,
) -> NDArray:
    """
    Create a tube mask with the shape of ``template``.

    The tube runs along ``symmetry_axis``, starts at the given base center and
    is bounded by ``inner_radius`` and ``outer_radius``. Extra keyword
    arguments are ignored.
    """
    base_center = (center_x, center_y, center_z)
    return create_mask(
        mask_type="tube",
        shape=template.shape,
        symmetry_axis=symmetry_axis,
        base_center=base_center,
        inner_radius=inner_radius,
        outer_radius=outer_radius,
        height=height,
        sigma_decay=sigma_decay,
    )
481
+
482
+
483
def wedge_mask(
    template: NDArray,
    tilt_start: float = 40.0,
    tilt_stop: float = 40.0,
    tilt_step: float = 0,
    opening_axis: int = 2,
    tilt_axis: int = 0,
    omit_negative_frequencies: bool = False,
    infinite_plane: bool = False,
    weight_angle: bool = False,
    **kwargs,
) -> NDArray:
    """
    Return a zero-frequency-centered missing-wedge mask shaped like ``template``.

    A continuous wedge is built when ``tilt_step <= 0``. Otherwise a step wedge
    is sampled at ``np.arange(-tilt_start, tilt_stop + tilt_step, tilt_step)``,
    optionally cosine-weighted per tilt (``weight_angle``). The mask is
    ``np.fft.fftshift``-ed for display.
    """
    if tilt_step <= 0:
        mask = preprocessor.continuous_wedge_mask(
            start_tilt=tilt_start,
            stop_tilt=tilt_stop,
            tilt_axis=tilt_axis,
            opening_axis=opening_axis,
            shape=template.shape,
            omit_negative_frequencies=omit_negative_frequencies,
            infinite_plane=infinite_plane,
        )
    else:
        angles = np.arange(-tilt_start, tilt_stop + tilt_step, tilt_step)
        angle_weights = np.cos(np.radians(angles)) if weight_angle else None
        mask = preprocessor.step_wedge_mask(
            tilt_angles=angles,
            tilt_axis=tilt_axis,
            opening_axis=opening_axis,
            shape=template.shape,
            weights=angle_weights,
            omit_negative_frequencies=omit_negative_frequencies,
        )

    return np.fft.fftshift(mask)
524
+
525
+
526
def threshold_mask(
    template: NDArray,
    invert: bool = False,
    standard_deviation: float = 5.0,
    sigma: float = 0.0,
    **kwargs,
) -> NDArray:
    """
    Mask voxels whose value lies far from the template mean.

    Parameters
    ----------
    template : NDArray
        Input data.
    invert : bool, optional
        Select the complement instead, i.e. voxels strictly inside the band.
    standard_deviation : float, optional
        Half-width of the band around the mean, in multiples of the template's
        standard deviation. Voxels on or beyond the band edges are selected.
    sigma : float, optional
        If non-zero, soften the (possibly inverted) mask with the Preprocessor
        Gaussian filter; values below ``exp(-sigma**2)`` are set to zero.

    Returns
    -------
    NDArray
        Boolean mask if ``sigma == 0``, otherwise a float mask.
    """
    template_mean = template.mean()
    template_deviation = standard_deviation * template.std()
    upper = template_mean + template_deviation
    lower = template_mean - template_deviation
    mask = np.logical_or(template <= lower, template >= upper)

    # Invert the binary selection before smoothing: np.invert on the float mask
    # produced by the smoothing step below raises TypeError.
    if invert:
        mask = np.logical_not(mask)

    if sigma != 0:
        binary = mask * 1.0
        mask_filter = preprocessor.gaussian_filter(template=binary, sigma=sigma)
        # Keep selected voxels at 1 and let the blurred edge fall off outside.
        mask = np.add(binary, (1 - binary) * mask_filter)
        mask[mask < np.exp(-np.square(sigma))] = 0

    return mask
548
+
549
+
550
def lowpass_mask(template: NDArray, sigma: float = 1.0, **kwargs):
    """
    Build a smoothed binary mask from ``template``.

    The template is scaled by its maximum, thresholded at ``exp(-2)`` into a
    0/128 volume, Gaussian filtered with ``sigma`` and thresholded again at
    ``exp(-2)``. Extra keyword arguments are ignored.
    """
    threshold = np.exp(-2)
    scaled = template / template.max()
    seed = (scaled > threshold) * 128.0
    smoothed = preprocessor.gaussian_filter(template=seed, sigma=sigma)
    return smoothed > threshold
557
+
558
+
559
def shape_mask(template, shapes_layer, expansion_dim):
    """
    Extrude the 2D masks of a shapes layer along one axis of ``template``.

    Parameters
    ----------
    template : NDArray
        Reference array; defines the output shape and dtype.
    shapes_layer
        Object providing ``to_masks(mask_shape=...)`` and ``shape_type``
        (a napari Shapes layer in the GUI).
    expansion_dim : int
        Axis of ``template`` along which the 2D masks are extended.

    Returns
    -------
    NDArray
        Array like ``template`` that is 1 wherever any shape covers the
        projected position and 0 elsewhere.
    """
    ret = np.zeros_like(template)
    mask_shape = tuple(x for i, x in enumerate(template.shape) if i != expansion_dim)
    masks = shapes_layer.to_masks(mask_shape=mask_shape)

    n_shapes = len(shapes_layer.shape_type)
    if n_shapes == 0:
        return ret

    # Union the per-shape masks once and broadcast along expansion_dim instead
    # of materialising a full-volume np.repeat copy for every shape.
    combined = np.any(np.asarray(masks)[:n_shapes], axis=0)
    np.logical_or(ret, np.expand_dims(combined, axis=expansion_dim), out=ret)
    return ret
571
+
572
+
573
class MaskWidget(widgets.Container):
    """
    Container for creating masks from the active napari layer.

    Parameter widgets of the selected mask function are generated with
    widgets_from_function. "Adapt to layer" pre-fills centers, radii and
    heights from the bounding box of the active layer's voxels above the
    chosen data quantile.
    """

    def __init__(self, viewer):
        super().__init__(layout="vertical")

        self.viewer = viewer
        # Parameter widgets of the currently selected mask function.
        self.action_widgets = []

        has_active_layer = bool(self.viewer.layers.selection.active)

        self.action_button = widgets.PushButton(
            text="Create mask", enabled=has_active_layer
        )
        self.action_button.changed.connect(self._action)

        self.methods = {
            "Sphere": sphere_mask,
            "Ellipsoid": ellipsod_mask,
            "Tube": tube_mask,
            "Box": box_mask,
            "Wedge": wedge_mask,
            "Threshold": threshold_mask,
            "Lowpass": lowpass_mask,
            "Shape": shape_mask,
        }

        self.method_dropdown = widgets.ComboBox(
            name="Choose Mask", choices=list(self.methods.keys())
        )
        self.method_dropdown.changed.connect(self._on_method_changed)

        self.percentile_range_edit = widgets.FloatSpinBox(
            name="Data Quantile", min=0, max=100, value=0, step=2
        )

        self.adapt_button = widgets.PushButton(
            text="Adapt to layer", enabled=has_active_layer
        )
        self.adapt_button.changed.connect(self._update_initial_values)
        self.viewer.layers.selection.events.active.connect(
            self._update_action_button_state
        )

        # Label reserved for an in-mask density readout (currently not filled).
        self.density_field = widgets.Label()

        self.shapes_layer_dropdown = widgets.ComboBox(
            name="shapes_layer", choices=self._get_shape_layers()
        )
        self.viewer.layers.events.inserted.connect(self._update_shape_layer_choices)
        self.viewer.layers.events.removed.connect(self._update_shape_layer_choices)

        self.append(self.method_dropdown)
        self.append(self.adapt_button)
        self.append(self.percentile_range_edit)

        self.append(self.action_button)
        self.append(self.density_field)

        # Create GUI for initially selected filtering method
        self._on_method_changed(None)

    def _update_action_button_state(self, event):
        """Enable the action buttons only while a layer is active."""
        self.action_button.enabled = bool(self.viewer.layers.selection.active)
        self.adapt_button.enabled = bool(self.viewer.layers.selection.active)

    def _update_initial_values(self, event=None):
        """Pre-fill mask parameters from the bounding box of the active layer."""
        active_layer = self.viewer.layers.selection.active
        if active_layer is None:
            return None

        data = active_layer.data.copy()
        cutoff = np.quantile(data, self.percentile_range_edit.value / 100)
        cutoff = max(cutoff, np.finfo(np.float32).resolution)
        data[data < cutoff] = 0

        coordinates = np.array(np.where(data > 0))
        if coordinates.size == 0:
            # Nothing above the cutoff, so there is no bounding box to adapt to.
            return None

        coordinates_min = coordinates.min(axis=1)
        coordinates_max = coordinates.max(axis=1)
        coordinates_heights = coordinates_max - coordinates_min
        coordinate_radius = np.divide(coordinates_heights, 2)
        bbox_center = coordinate_radius + coordinates_min

        defaults = dict(zip(["center_x", "center_y", "center_z"], bbox_center))
        defaults.update(
            dict(zip(["radius_x", "radius_y", "radius_z"], coordinate_radius))
        )
        defaults.update(
            dict(zip(["height_x", "height_y", "height_z"], coordinates_heights))
        )

        defaults["radius"] = np.max(coordinate_radius)
        defaults["inner_radius"] = np.min(coordinate_radius)
        defaults["outer_radius"] = np.max(coordinate_radius)
        defaults["height"] = np.max(coordinates_heights)

        for widget in self.action_widgets:
            if widget.name in defaults:
                widget.value = defaults[widget.name]

    def _on_method_changed(self, event=None):
        """Replace the parameter widgets with those of the selected mask."""
        for widget in self.action_widgets:
            self.remove(widget)
        self.action_widgets.clear()

        function = self.methods.get(self.method_dropdown.value)
        # Insert right after the method dropdown (index 0), preserving the
        # parameter order of the function signature.
        for position, widget in enumerate(widgets_from_function(function), start=1):
            self.action_widgets.append(widget)
            self.insert(position, widget)

        if "shapes_layer" in inspect.signature(function).parameters:
            self.action_widgets.append(self.shapes_layer_dropdown)
            self.insert(1, self.shapes_layer_dropdown)

    def _get_shape_layers(self):
        """Return the names of all Shapes layers in the viewer."""
        return [
            layer.name
            for layer in self.viewer.layers
            if isinstance(layer, napari.layers.Shapes)
        ]

    def _update_shape_layer_choices(self, event):
        """Update the choices in the shapes layer dropdown."""
        self.shapes_layer_dropdown.choices = self._get_shape_layers()

    def _action(self):
        """Create, or update in place, the mask layer for the active layer."""
        function = self.methods.get(self.method_dropdown.value)

        selected_layer = self.viewer.layers.selection.active
        if selected_layer is None:
            return None
        kwargs = {widget.name: widget.value for widget in self.action_widgets}

        if "shapes_layer" in kwargs:
            layer_name = kwargs["shapes_layer"]
            if layer_name not in self.viewer.layers:
                return None
            kwargs["shapes_layer"] = self.viewer.layers[layer_name]
            kwargs["expansion_dim"] = self.viewer.dims.order[0]

        processed_data = function(template=selected_layer.data, **kwargs)
        processed_data = processed_data.astype(np.float32)

        new_layer_name = f"{selected_layer.name} ({self.method_dropdown.value})"
        if new_layer_name in self.viewer.layers:
            selected_layer = self.viewer.layers[new_layer_name]

        if selected_layer.metadata.get("mask", False) == self.method_dropdown.value:
            selected_layer.data = processed_data
            return None

        new_layer = self.viewer.add_image(
            data=processed_data,
            name=new_layer_name,
        )
        metadata = selected_layer.metadata.copy()
        metadata["filter_parameters"] = {self.method_dropdown.value: kwargs.copy()}
        metadata["mask"] = self.method_dropdown.value
        metadata["origin_layer"] = selected_layer.name
        new_layer.metadata = metadata

        if self.method_dropdown.value == "Shape":
            # Shape masks get empty metadata (their kwargs hold the Shapes
            # layer object itself, presumably not meant to be stored).
            new_layer.metadata = {}

        # NOTE: a positive-density-in-mask readout for self.density_field is
        # currently disabled.
740
+
741
+
742
class AlignmentWidget(widgets.Container):
    """
    Container with tools to reorient the active layer's data.

    "Align to axis" rotates the data so that the principal axis of its
    positive voxels points along the selected array axis. The rotate buttons
    apply ``np.rot90`` in the plane spanned by the selected axis (or the first
    entry of ``viewer.dims.order``) and the lowest-numbered other axis.
    """

    def __init__(self, viewer):
        super().__init__(layout="vertical")

        self.viewer = viewer

        align_button = widgets.PushButton(text="Align to axis", enabled=True)
        self.align_axis = widgets.ComboBox(
            value=None, nullable=True, choices=self._get_active_layer_dims
        )
        self.viewer.layers.selection.events.changed.connect(self._update_align_axis)

        align_button.changed.connect(self._align_with_axis)
        container = widgets.Container(
            widgets=[align_button, self.align_axis], layout="horizontal"
        )
        self.append(container)

        rot90 = widgets.PushButton(text="Rotate 90", enabled=True)
        rotneg90 = widgets.PushButton(text="Rotate -90", enabled=True)

        # Use a zero-argument callback so the value emitted by the button's
        # signal is never bound to the swap_axes parameter.
        rot90.changed.connect(lambda: self._rot90(swap_axes=False))
        rotneg90.changed.connect(self._rotneg90)

        container = widgets.Container(widgets=[rot90, rotneg90], layout="horizontal")
        self.append(container)

    def _rot90(self, swap_axes: bool = False):
        """Rotate the active layer by 90 degrees (only in 2D display mode)."""
        active_layer = self.viewer.layers.selection.active
        if active_layer is None:
            return None
        elif self.viewer.dims.ndisplay != 2:
            return None

        align_axis = self.align_axis.value
        if align_axis is None:
            align_axis = self.viewer.dims.order[0]

        axes = [
            align_axis,
            *[i for i in range(len(self.viewer.dims.order)) if i != align_axis],
        ][:2]
        axes = axes[::-1] if swap_axes else axes
        active_layer.data = np.rot90(active_layer.data, k=1, axes=axes)

    def _rotneg90(self):
        """Rotate the active layer by -90 degrees."""
        return self._rot90(swap_axes=True)

    def _get_active_layer_dims(self, *args):
        """Return the axis indices of the active layer's data, or ()."""
        active_layer = self.viewer.layers.selection.active
        if active_layer is None:
            return ()
        try:
            return list(range(active_layer.data.ndim))
        except Exception:
            # Best effort: layers whose data has no ndim offer no axes.
            return ()

    def _update_align_axis(self, *args):
        """Refresh the axis choices after the layer selection changed."""
        self.align_axis.choices = self._get_active_layer_dims()

    def _align_with_axis(self):
        """Rotate the active layer so its principal axis matches the chosen axis."""
        active_layer = self.viewer.layers.selection.active
        if active_layer is None or self.align_axis.value is None:
            return None

        if active_layer.metadata.get("is_aligned", None) == self.align_axis.value:
            return None

        alignment_axis = np.zeros(active_layer.data.ndim)
        alignment_axis[int(self.align_axis.value)] = 1

        coords = np.array(np.where(active_layer.data > 0)).T
        if coords.shape[0] < 2:
            # The covariance, and hence the principal axis, is undefined for
            # fewer than two positive voxels.
            return None

        centered_coords = coords - np.mean(coords, axis=0)
        cov_matrix = np.cov(centered_coords, rowvar=False)
        eigenvalues, eigenvectors = np.linalg.eigh(cov_matrix)
        principal_eigenvector = eigenvectors[:, np.argmax(eigenvalues)]

        rotation_matrix = align_vectors(principal_eigenvector, alignment_axis)
        rotated_data, _ = backend.rigid_transform(
            arr=active_layer.data,
            rotation_matrix=rotation_matrix,
            use_geometric_center=False,
            order=1,
        )
        eps = np.finfo(rotated_data.dtype).eps
        rotated_data[rotated_data < eps] = 0

        active_layer.metadata["is_aligned"] = int(self.align_axis.value)
        active_layer.data = rotated_data
832
+
833
+
834
class ExportWidget(widgets.Container):
    """Dock widget that saves the active image layer to an MRC file.

    The Export button is only enabled while the active layer is an ``Image``
    layer, because ``_export_data`` can only write image layers.
    """

    def __init__(self, viewer):
        super().__init__(layout="vertical")

        self.viewer = viewer
        # Path chosen in the most recent save dialog.
        self.selected_filename = ""

        horizontal_container = widgets.Container(layout="horizontal")

        self.gzip_output = widgets.CheckBox(name="gzip", value=False, label="gzip")
        self.export_button = widgets.PushButton(name="Export", text="Export")
        self.export_button.clicked.connect(self._get_save_path)

        horizontal_container.append(self.export_button)
        horizontal_container.append(self.gzip_output)

        self.append(horizontal_container)

        self._update_export_button_state()
        self.viewer.layers.selection.events.active.connect(
            self._update_export_button_state
        )

    def _get_save_path(self, event):
        """Ask for an output path and export the active layer to it.

        Nothing is exported if the dialog is cancelled (empty path).
        """
        path, _ = QFileDialog().getSaveFileName(
            self.native,
            "Save As...",
            "",
            "MRC Files (*.mrc)",
        )
        if path:
            self.selected_filename = path
            self._export_data()

    def _update_export_button_state(self, event=None):
        """Enable the export button only when the active layer is an image.

        ``event`` is ignored; it defaults to None so the method can also be
        called directly during initialization.
        """
        self.export_button.enabled = isinstance(
            self.viewer.layers.selection.active, Image
        )

    def _export_data(self):
        """Save the active ``Image`` layer to ``self.selected_filename``.

        The gzip checkbox state is stored in the layer's ``write_gzip``
        metadata entry before saving. Presumably the registered writer reads it
        to decide on compression; confirm against the writer plugin.
        """
        selected_layer = self.viewer.layers.selection.active
        if selected_layer and isinstance(selected_layer, Image):
            selected_layer.metadata["write_gzip"] = self.gzip_output.value
            selected_layer.save(path=self.selected_filename)
877
+
878
+
879
class PointCloudWidget(widgets.Container):
    """Dock widget to import, filter, annotate and export point clouds.

    Point clouds are read and written through ``Orientations``. Each points
    layer carries a ``detail`` property used as an annotation label: -1 for
    none, 0 set via the Negative button, 1 set via the Positive button.
    Layers loaded from a file also carry ``score`` and ``score_scaled``
    properties, which are used for percentile filtering and coloring.
    """

    def __init__(self, viewer):
        super().__init__(layout="vertical")

        self.viewer = viewer
        # File contents per layer name; used on export to restore rotations and
        # scores for points that still sit at their loaded coordinates.
        self.dataframes = {}
        # Label assigned on click: 1 positive, 0 negative, -1 none.
        self.selected_category = -1

        self.import_button = widgets.PushButton(
            name="Import", text="Import Point Cloud"
        )
        self.import_button.clicked.connect(self._get_load_path)

        self.export_button = widgets.PushButton(
            name="Export", text="Export Point Cloud"
        )
        self.export_button.clicked.connect(self._export_point_cloud)
        self.export_button.enabled = False

        self.score_filter_container = widgets.Container(
            name="Score Filter", layout="vertical"
        )
        self.score_range_slider = widgets.FloatRangeSlider(
            min=0,
            max=100,
            step=0.1,
            value=(0, 100),
            readout=False,
        )
        self.score_range_slider.changed.connect(self._update_point_visibility)
        self.score_filter_container.append(self.score_range_slider)

        self.annotation_container = widgets.Container(name="Label", layout="horizontal")
        self.positive_button = widgets.PushButton(name="Positive", text="Positive")
        self.negative_button = widgets.PushButton(name="Negative", text="Negative")
        self.positive_button.clicked.connect(self._set_positive)
        self.negative_button.clicked.connect(self._set_negative)
        self.annotation_container.append(self.positive_button)
        self.annotation_container.append(self.negative_button)

        self.face_color_select = widgets.ComboBox(
            name="Color", choices=["Label", "Score"], value=None, nullable=True
        )
        self.face_color_select.changed.connect(self._update_face_color_mode)

        self.append(self.import_button)
        self.append(self.export_button)
        self.append(self.score_filter_container)
        self.append(self.annotation_container)
        self.append(self.face_color_select)

        self.viewer.layers.selection.events.changed.connect(self._update_buttons)

        self.viewer.layers.events.inserted.connect(self._initialize_points_layer)

    def _update_point_visibility(self, event=None):
        """Show only points whose score lies within the slider's percentile range.

        Applies to every points layer that has a ``score`` property and clears
        its selection.
        """
        min_percentile, max_percentile = self.score_range_slider.value

        for layer in self.viewer.layers:
            if not isinstance(layer, napari.layers.Points):
                continue

            if "score" not in layer.properties:
                continue

            scores = layer.properties["score"]
            # np.percentile is undefined for an empty array.
            if len(scores) == 0:
                continue

            layer.selected_data = set()

            min_score = np.percentile(scores, min_percentile)
            max_score = np.percentile(scores, max_percentile)
            visible_mask = (scores >= min_score) & (scores <= max_score)
            layer.shown = visible_mask

    def _update_face_color_mode(self, event: str = None):
        """Color all points layers by label, by score, or plain white.

        Parameters
        ----------
        event : str, optional
            ``"Label"`` colors by the ``detail`` property (grey/red/green for
            -1/0/1). ``"Score"`` colors by ``score_scaled`` using the turbo
            colormap. Any other value leaves the points white. Layers that
            lack the required property stay white.
        """
        for layer in self.viewer.layers:
            if not isinstance(layer, napari.layers.Points):
                continue

            layer.face_color = "white"
            if event == "Label":
                if len(layer.properties.get("detail", ())) == 0:
                    continue
                layer.face_color = "detail"
                layer.face_color_cycle = {
                    -1: "grey",
                    0: "red",
                    1: "green",
                }
                layer.face_color_mode = "cycle"
            elif event == "Score":
                if len(layer.properties.get("score_scaled", ())) == 0:
                    continue
                layer.face_color = "score_scaled"
                layer.face_colormap = "turbo"
                layer.face_color_mode = "colormap"

            layer.refresh_colors()

        return None

    def _set_positive(self, event):
        """Toggle the positive (1) label for subsequent clicks."""
        self.selected_category = 1 if self.selected_category != 1 else -1
        self._update_annotation_buttons()

    def _set_negative(self, event):
        """Toggle the negative (0) label for subsequent clicks."""
        self.selected_category = 0 if self.selected_category != 0 else -1
        self._update_annotation_buttons()

    def _update_annotation_buttons(self):
        """Highlight the button that matches the currently selected label."""
        selected_style = "background-color: darkgrey"
        default_style = "background-color: none"

        self.positive_button.native.setStyleSheet(
            selected_style if self.selected_category == 1 else default_style
        )
        self.negative_button.native.setStyleSheet(
            selected_style if self.selected_category == 0 else default_style
        )

    def _initialize_points_layer(self, event):
        """Give a newly inserted points layer a ``detail`` property and click handler."""
        layer = event.value
        if not isinstance(layer, napari.layers.Points):
            return
        if len(layer.properties) == 0:
            layer.properties = {"detail": [-1]}
        elif "detail" not in layer.properties:
            # Layers do not support item assignment; extend the property table
            # with one unlabeled (-1) entry per existing point instead.
            layer.properties = {
                **layer.properties,
                "detail": np.full(len(layer.data), -1),
            }

        layer.mouse_drag_callbacks.append(self._on_click)
        return None

    def _on_click(self, layer, event):
        """Apply the selected label on click and refresh the point colors.

        In ``add`` mode the label becomes the default ``detail`` for the point
        being added. In ``select`` mode all selected points are relabeled.
        """
        if layer.mode == "add":
            layer.current_properties["detail"][-1] = self.selected_category
        elif layer.mode == "select":
            for index in layer.selected_data:
                layer.properties["detail"][index] = self.selected_category

        # TODO: Check whether current face color is the desired one already
        self._update_face_color_mode(self.face_color_select.value)
        layer.refresh_colors()

    def _update_buttons(self, event):
        """Enable point-cloud export only while a points layer is active."""
        is_pointcloud = isinstance(
            self.viewer.layers.selection.active, napari.layers.Points
        )
        if self.viewer.layers.selection.active and is_pointcloud:
            self.export_button.enabled = True
        else:
            self.export_button.enabled = False

    def _export_point_cloud(self, event):
        """Write the visible points of the active layer to an orientations file.

        Rotations and scores are recovered from the dataframe recorded at load
        time by matching exact x/y/z coordinates. Points without a match (for
        example, points added by hand) get zero Euler angles and a score of 1.
        The ``detail`` column is taken from the layer's current annotations.
        Coordinates are truncated to integers.
        """
        filename, _ = QFileDialog().getSaveFileName(
            self.native,
            "Save Point Cloud File...",
            "",
            "Orientations (*.tsv *.star);; All Files (*)",
        )

        if not filename:
            return None

        layer = self.viewer.layers.selection.active
        if not (layer and isinstance(layer, napari.layers.Points)):
            return None

        original_dataframe = self.dataframes.get(layer.name)
        if original_dataframe is None:
            # Use float-typed empty columns. Merging float coordinates onto an
            # untyped (object) empty frame raises a dtype-mismatch ValueError.
            original_dataframe = pd.DataFrame(
                {
                    col: pd.Series(dtype=float)
                    for col in ["x", "y", "z", "euler_x", "euler_y", "euler_z"]
                }
            )
        # Duplicate coordinates would multiply rows in the merge and break the
        # one-row-per-point alignment with ``detail`` and ``layer.shown``.
        original_dataframe = original_dataframe.drop_duplicates(subset=["x", "y", "z"])

        export_data = pd.DataFrame(layer.data, columns=["x", "y", "z"])
        merged_data = pd.merge(
            export_data, original_dataframe, on=["x", "y", "z"], how="left"
        )

        merged_data["z"] = merged_data["z"].astype(int)
        merged_data["y"] = merged_data["y"].astype(int)
        merged_data["x"] = merged_data["x"].astype(int)

        euler_columns = ["euler_z", "euler_y", "euler_x"]
        for col in euler_columns:
            if col not in merged_data.columns:
                continue
            merged_data[col] = merged_data[col].fillna(0)

        if "score" in merged_data.columns:
            merged_data["score"] = merged_data["score"].fillna(1)
        else:
            # No score information at all: use the same default of 1 that
            # unmatched points receive above.
            merged_data["score"] = 1
        merged_data["detail"] = layer.properties["detail"]
        merged_data = merged_data[layer.shown]

        translations = np.stack(
            (merged_data["x"], merged_data["y"], merged_data["z"]), axis=1
        )
        rotations = np.stack(
            (
                merged_data["euler_x"],
                merged_data["euler_y"],
                merged_data["euler_z"],
            ),
            axis=1,
        )
        orientations = Orientations(
            translations=translations,
            rotations=rotations,
            scores=merged_data["score"],
            details=merged_data["detail"],
        )
        orientations.to_file(filename)

    def _get_load_path(self, event):
        """Ask for an orientations file and load it as a points layer."""
        filename, _ = QFileDialog().getOpenFileName(
            self.native,
            "Open Point Cloud File...",
            "",
            "Orientations (*.tsv *.star);; All Files (*)",
        )
        if filename:
            self._load_point_cloud(filename)

    def _load_point_cloud(self, filename):
        """Load an orientations file as a points layer and remember its contents.

        Adds ``score``, ``detail`` and ``score_scaled`` (log1p of the score
        shifted to a minimum of 0) as point properties, and resets the score
        filter slider.
        """
        orientations = Orientations.from_file(filename)

        data = {
            "x": orientations.translations[:, 0],
            "y": orientations.translations[:, 1],
            "z": orientations.translations[:, 2],
            "euler_x": orientations.rotations[:, 0],
            "euler_y": orientations.rotations[:, 1],
            "euler_z": orientations.rotations[:, 2],
            "score": orientations.scores,
            "detail": orientations.details,
        }

        dataframe = pd.DataFrame.from_dict(data)
        # basename also handles platform-specific path separators.
        layer_name = basename(filename)

        if "score" not in dataframe.columns:
            dataframe["score"] = 1

        if "detail" not in dataframe.columns:
            dataframe["detail"] = -1

        point_properties = {
            "score": np.array(dataframe["score"].values),
            "detail": np.array(dataframe["detail"].values),
        }
        point_properties["score_scaled"] = np.log1p(
            point_properties["score"] - point_properties["score"].min()
        )
        self.score_range_slider.value = (0, 100)
        layer = self.viewer.add_points(
            dataframe[["x", "y", "z"]].values,
            size=10,
            properties=point_properties,
            name=layer_name,
        )
        # napari may adjust the name to keep layer names unique (e.g. when the
        # same file is loaded twice); key by the name the layer actually got so
        # _export_point_cloud finds this dataframe again.
        self.dataframes[layer.name] = dataframe
1135
+
1136
+
1137
class MatchingWidget(widgets.Container):
    """Dock widget that loads pickled template-matching output into the viewer."""

    def __init__(self, viewer):
        super().__init__(layout="vertical")

        self.viewer = viewer
        self.dataframes = {}

        self.import_button = widgets.PushButton(name="Import", text="Import Pickle")
        self.import_button.clicked.connect(self._get_load_path)

        self.append(self.import_button)

    def _get_load_path(self, event):
        """Ask for a pickle file and load it."""
        filename, _ = QFileDialog.getOpenFileName(
            self.native,
            "Open Pickle File...",
            "",
            "Pickle Files (*.pickle);;All Files (*)",
        )
        if filename:
            self._load_data(filename)

    def _load_data(self, filename):
        """Add the contents of a pickle file to the viewer.

        If ``data[0]`` and ``data[2]`` have the same dimensionality, they are
        shown as ``<name>_scores`` and ``<name>_rotations`` image layers, with
        origin and sampling rate taken from ``data[-1][1]`` / ``data[-1][2]``.
        Otherwise ``data[0]`` holds point coordinates and ``data[2]`` per-point
        scores, and ``data[3]`` (if present) holds per-point details. These
        are added as a ``<name>_candidates`` points layer.
        """
        # NOTE(review): load_pickle presumably unpickles, which can execute
        # arbitrary code -- only open files from trusted sources.
        data = load_pickle(filename)

        fname = basename(filename).replace(".pickle", "")
        if data[0].ndim == data[2].ndim:
            metadata = {"origin": data[-1][1], "sampling_rate": data[-1][2]}
            _ = self.viewer.add_image(
                data=data[2],
                name=f"{fname}_rotations",
                colormap="orange",
                metadata=metadata,
            )
            _ = self.viewer.add_image(
                data=data[0],
                name=f"{fname}_scores",
                colormap="turbo",
                metadata=metadata,
            )
            return None

        # Fall back to all-zero details when the file provides none.
        detail = data[3] if len(data) > 3 else np.zeros_like(data[2])

        point_properties = {"score": data[2], "detail": detail}
        point_properties["score_scaled"] = np.log1p(
            point_properties["score"] - point_properties["score"].min()
        )
        layer_name = f"{fname}_candidates"
        self.viewer.add_points(
            data[0],
            size=10,
            properties=point_properties,
            name=layer_name,
        )
1193
+
1194
+
1195
def main():
    """Open a napari viewer with all preprocessing and analysis widgets docked."""
    viewer = napari.Viewer()

    # Construct in a fixed order: several widgets connect to viewer events in
    # their constructors, so construction order is kept separate from dock order.
    filter_widget = FilterWidget(preprocessor, viewer)
    mask_widget = MaskWidget(viewer)
    export_widget = ExportWidget(viewer)
    point_cloud = PointCloudWidget(viewer)
    matching_widget = MatchingWidget(viewer)
    alignment_widget = AlignmentWidget(viewer)

    dock_layout = (
        (filter_widget, "Preprocess", "right"),
        (alignment_widget, "Alignment", "right"),
        (mask_widget, "Mask", "right"),
        (export_widget, "Export", "right"),
        (point_cloud, "PointCloud", "left"),
        (matching_widget, "Matching", "left"),
    )
    for dock_widget, dock_name, dock_area in dock_layout:
        viewer.window.add_dock_widget(
            widget=dock_widget, name=dock_name, area=dock_area
        )

    napari.run()
1215
+
1216
+
1217
def parse_args(argv=None):
    """Parse command-line arguments for the preprocessing GUI.

    No options are defined yet. Parsing still provides ``--help`` and rejects
    unknown arguments (argparse exits with status 2).

    Parameters
    ----------
    argv : list of str, optional
        Arguments to parse. The default ``None`` makes argparse read
        ``sys.argv[1:]``, matching the previous behavior.

    Returns
    -------
    argparse.Namespace
        The parsed (currently empty) namespace.
    """
    parser = argparse.ArgumentParser(
        description="GUI for preparing and analyzing template matching runs."
    )
    return parser.parse_args(argv)
1223
+
1224
+
1225
if __name__ == "__main__":
    # Parse first so --help or invalid arguments exit before the viewer opens.
    _ = parse_args()
    main()