pytme 0.2.9__cp311-cp311-macosx_15_0_arm64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- pytme-0.2.9.data/scripts/estimate_ram_usage.py +97 -0
- pytme-0.2.9.data/scripts/match_template.py +1135 -0
- pytme-0.2.9.data/scripts/postprocess.py +622 -0
- pytme-0.2.9.data/scripts/preprocess.py +209 -0
- pytme-0.2.9.data/scripts/preprocessor_gui.py +1227 -0
- pytme-0.2.9.dist-info/METADATA +95 -0
- pytme-0.2.9.dist-info/RECORD +119 -0
- pytme-0.2.9.dist-info/WHEEL +5 -0
- pytme-0.2.9.dist-info/entry_points.txt +6 -0
- pytme-0.2.9.dist-info/licenses/LICENSE +153 -0
- pytme-0.2.9.dist-info/top_level.txt +3 -0
- scripts/__init__.py +0 -0
- scripts/estimate_ram_usage.py +97 -0
- scripts/match_template.py +1135 -0
- scripts/postprocess.py +622 -0
- scripts/preprocess.py +209 -0
- scripts/preprocessor_gui.py +1227 -0
- tests/__init__.py +0 -0
- tests/data/Blurring/blob_width18.npy +0 -0
- tests/data/Blurring/edgegaussian_sigma3.npy +0 -0
- tests/data/Blurring/gaussian_sigma2.npy +0 -0
- tests/data/Blurring/hamming_width6.npy +0 -0
- tests/data/Blurring/kaiserb_width18.npy +0 -0
- tests/data/Blurring/localgaussian_sigma0510.npy +0 -0
- tests/data/Blurring/mean_size5.npy +0 -0
- tests/data/Blurring/ntree_sigma0510.npy +0 -0
- tests/data/Blurring/rank_rank3.npy +0 -0
- tests/data/Maps/.DS_Store +0 -0
- tests/data/Maps/emd_8621.mrc.gz +0 -0
- tests/data/README.md +2 -0
- tests/data/Raw/em_map.map +0 -0
- tests/data/Structures/.DS_Store +0 -0
- tests/data/Structures/1pdj.cif +3339 -0
- tests/data/Structures/1pdj.pdb +1429 -0
- tests/data/Structures/5khe.cif +3685 -0
- tests/data/Structures/5khe.ent +2210 -0
- tests/data/Structures/5khe.pdb +2210 -0
- tests/data/Structures/5uz4.cif +70548 -0
- tests/preprocessing/__init__.py +0 -0
- tests/preprocessing/test_compose.py +76 -0
- tests/preprocessing/test_frequency_filters.py +178 -0
- tests/preprocessing/test_preprocessor.py +136 -0
- tests/preprocessing/test_utils.py +79 -0
- tests/test_analyzer.py +216 -0
- tests/test_backends.py +446 -0
- tests/test_density.py +503 -0
- tests/test_extensions.py +130 -0
- tests/test_matching_cli.py +283 -0
- tests/test_matching_data.py +162 -0
- tests/test_matching_exhaustive.py +124 -0
- tests/test_matching_memory.py +30 -0
- tests/test_matching_optimization.py +226 -0
- tests/test_matching_utils.py +189 -0
- tests/test_orientations.py +175 -0
- tests/test_parser.py +33 -0
- tests/test_rotations.py +153 -0
- tests/test_structure.py +247 -0
- tme/__init__.py +6 -0
- tme/__version__.py +1 -0
- tme/analyzer/__init__.py +2 -0
- tme/analyzer/_utils.py +186 -0
- tme/analyzer/aggregation.py +577 -0
- tme/analyzer/peaks.py +953 -0
- tme/backends/__init__.py +171 -0
- tme/backends/_cupy_utils.py +734 -0
- tme/backends/_jax_utils.py +188 -0
- tme/backends/cupy_backend.py +294 -0
- tme/backends/jax_backend.py +314 -0
- tme/backends/matching_backend.py +1270 -0
- tme/backends/mlx_backend.py +241 -0
- tme/backends/npfftw_backend.py +583 -0
- tme/backends/pytorch_backend.py +430 -0
- tme/data/__init__.py +0 -0
- tme/data/c48n309.npy +0 -0
- tme/data/c48n527.npy +0 -0
- tme/data/c48n9.npy +0 -0
- tme/data/c48u1.npy +0 -0
- tme/data/c48u1153.npy +0 -0
- tme/data/c48u1201.npy +0 -0
- tme/data/c48u1641.npy +0 -0
- tme/data/c48u181.npy +0 -0
- tme/data/c48u2219.npy +0 -0
- tme/data/c48u27.npy +0 -0
- tme/data/c48u2947.npy +0 -0
- tme/data/c48u3733.npy +0 -0
- tme/data/c48u4749.npy +0 -0
- tme/data/c48u5879.npy +0 -0
- tme/data/c48u7111.npy +0 -0
- tme/data/c48u815.npy +0 -0
- tme/data/c48u83.npy +0 -0
- tme/data/c48u8649.npy +0 -0
- tme/data/c600v.npy +0 -0
- tme/data/c600vc.npy +0 -0
- tme/data/metadata.yaml +80 -0
- tme/data/quat_to_numpy.py +42 -0
- tme/data/scattering_factors.pickle +0 -0
- tme/density.py +2263 -0
- tme/extensions.cpython-311-darwin.so +0 -0
- tme/external/bindings.cpp +332 -0
- tme/filters/__init__.py +6 -0
- tme/filters/_utils.py +311 -0
- tme/filters/bandpass.py +230 -0
- tme/filters/compose.py +81 -0
- tme/filters/ctf.py +393 -0
- tme/filters/reconstruction.py +160 -0
- tme/filters/wedge.py +542 -0
- tme/filters/whitening.py +191 -0
- tme/matching_data.py +863 -0
- tme/matching_exhaustive.py +497 -0
- tme/matching_optimization.py +1311 -0
- tme/matching_scores.py +1183 -0
- tme/matching_utils.py +1188 -0
- tme/memory.py +337 -0
- tme/orientations.py +598 -0
- tme/parser.py +685 -0
- tme/preprocessor.py +1329 -0
- tme/rotations.py +350 -0
- tme/structure.py +1864 -0
- tme/types.py +13 -0
@@ -0,0 +1,1227 @@
|
|
1
|
+
#!python
|
2
|
+
""" GUI for identifying adequate template matching filter and masks.
|
3
|
+
|
4
|
+
Copyright (c) 2023 European Molecular Biology Laboratory
|
5
|
+
|
6
|
+
Author: Valentin Maurer <valentin.maurer@embl-hamburg.de>
|
7
|
+
"""
|
8
|
+
import inspect
|
9
|
+
import argparse
|
10
|
+
from os.path import basename
|
11
|
+
from typing_extensions import Annotated
|
12
|
+
from typing import Tuple, Callable, List
|
13
|
+
|
14
|
+
import napari
|
15
|
+
import numpy as np
|
16
|
+
import pandas as pd
|
17
|
+
from magicgui import widgets
|
18
|
+
from numpy.typing import NDArray
|
19
|
+
from napari.layers import Image
|
20
|
+
from scipy.fft import next_fast_len
|
21
|
+
from qtpy.QtWidgets import QFileDialog
|
22
|
+
from napari.utils.events import EventedList
|
23
|
+
|
24
|
+
from tme.backends import backend
|
25
|
+
from tme.rotations import align_vectors
|
26
|
+
from tme.filters import BandPassFilter, CTF
|
27
|
+
from tme import Preprocessor, Density, Orientations
|
28
|
+
from tme.matching_utils import create_mask, load_pickle
|
29
|
+
|
30
|
+
# Module-wide Preprocessor instance shared by the filter and mask wrappers.
preprocessor = Preprocessor()

# Bounds used when building float parameter widgets: SLIDER_MIN is the lower
# bound of spin boxes and range sliders, SLIDER_MAX the upper bound of sliders.
SLIDER_MIN = 0
SLIDER_MAX = 25
|
32
|
+
|
33
|
+
|
34
|
+
def gaussian_filter(template: NDArray, sigma: float, **kwargs: dict) -> NDArray:
    """Apply ``Preprocessor.gaussian_filter`` to ``template``.

    Parameters
    ----------
    template : NDArray
        Input array.
    sigma : float
        Filter width forwarded as ``sigma``.
    **kwargs : dict
        Forwarded unchanged to ``Preprocessor.gaussian_filter``.

    Returns
    -------
    NDArray
        The filtered array returned by the preprocessor.
    """
    filtered = preprocessor.gaussian_filter(
        template=template,
        sigma=sigma,
        **kwargs,
    )
    return filtered
|
36
|
+
|
37
|
+
|
38
|
+
def bandpass_filter(
    template: NDArray,
    lowpass_angstrom: float = 30,
    highpass_angstrom: float = 140,
    hard_edges: bool = False,
    sampling_rate=None,
) -> NDArray:
    """Band-pass filter ``template`` in real-Fourier space.

    Parameters
    ----------
    template : NDArray
        Real-valued input array.
    lowpass_angstrom : float
        Passed to ``BandPassFilter`` as ``lowpass``.
    highpass_angstrom : float
        Passed to ``BandPassFilter`` as ``highpass``.
    hard_edges : bool
        If True, ``use_gaussian=False`` is requested from the filter.
    sampling_rate : array_like, optional
        Reduced with ``np.max`` before being handed to the filter.

    Returns
    -------
    NDArray
        Filtered array with the same shape as ``template``.
    """
    real_shape = template.shape
    spectrum = np.fft.rfftn(template, s=real_shape)

    band_filter = BandPassFilter(
        lowpass=lowpass_angstrom,
        highpass=highpass_angstrom,
        sampling_rate=np.max(sampling_rate),
        use_gaussian=not hard_edges,
        shape_is_real_fourier=True,
        return_real_fourier=True,
    )
    # The filter is evaluated directly on the half-spectrum shape.
    weights = band_filter(shape=spectrum.shape)["data"]
    spectrum *= weights

    return np.fft.irfftn(spectrum, s=real_shape).real
|
58
|
+
|
59
|
+
|
60
|
+
def ctf_filter(
    template: NDArray,
    defocus_angstrom: float = 30000,
    acceleration_voltage: float = 300,
    spherical_aberration: float = 2.7,
    amplitude_contrast: float = 0.07,
    phase_shift: float = 0,
    defocus_angle: float = 0,
    sampling_rate=None,
    flip_phase: bool = False,
) -> NDArray:
    """Multiply ``template`` with a contrast transfer function in Fourier space.

    The template is zero-padded at its top-left corner to a fast FFT size of
    (at least) twice its shape, multiplied with the real-Fourier CTF built by
    ``CTF``, transformed back and cropped to the original shape.

    Parameters
    ----------
    template : NDArray
        Real-valued input array.
    defocus_angstrom : float
        Passed as the single ``defocus_x`` value.
    acceleration_voltage : float
        Multiplied by 1e3 before being passed on (presumably kV -> V).
    spherical_aberration : float
        Multiplied by 1e7 before being passed on (presumably mm -> Angstrom).
    amplitude_contrast, phase_shift, defocus_angle, flip_phase
        Forwarded to ``CTF``.
    sampling_rate : array_like, optional
        Reduced with ``np.max`` before being handed to ``CTF``.

    Returns
    -------
    NDArray
        Filtered array with the same shape as ``template``.
    """
    # Padding to ~2x avoids wrap-around; next_fast_len keeps the FFT efficient.
    padded_shape = [next_fast_len(dim) for dim in np.multiply(template.shape, 2)]
    padded = backend.topleft_pad(template, padded_shape)
    spectrum = np.fft.rfftn(padded, s=padded.shape)

    transfer = CTF(
        angles=[0],
        shape=padded_shape,
        defocus_x=[defocus_angstrom],
        acceleration_voltage=acceleration_voltage * 1e3,
        spherical_aberration=spherical_aberration * 1e7,
        amplitude_contrast=amplitude_contrast,
        phase_shift=[phase_shift],
        defocus_angle=[defocus_angle],
        sampling_rate=np.max(sampling_rate),
        return_real_fourier=True,
        flip_phase=flip_phase,
    )
    spectrum *= transfer()["data"]

    filtered = np.fft.irfftn(spectrum, s=padded.shape).real
    return backend.topleft_pad(filtered, template.shape)
|
91
|
+
|
92
|
+
|
93
|
+
def difference_of_gaussian_filter(
    template: NDArray, sigmas: Tuple[float, float], **kwargs: dict
) -> NDArray:
    """Apply ``Preprocessor.difference_of_gaussian_filter`` to ``template``.

    Parameters
    ----------
    template : NDArray
        Input array.
    sigmas : tuple of float
        ``(low_sigma, high_sigma)``; must contain exactly two values.
    **kwargs : dict
        Forwarded unchanged to the preprocessor.

    Returns
    -------
    NDArray
        The filtered array returned by the preprocessor.
    """
    lower, upper = sigmas
    options = {"low_sigma": lower, "high_sigma": upper}
    return preprocessor.difference_of_gaussian_filter(
        template=template, **options, **kwargs
    )
|
100
|
+
|
101
|
+
|
102
|
+
def edge_gaussian_filter(
    template: NDArray,
    sigma: float,
    edge_algorithm: Annotated[
        str,
        {"choices": ["sobel", "prewitt", "laplace", "gaussian", "gaussian_laplace"]},
    ],
    reverse: bool = False,
    **kwargs: dict,
) -> NDArray:
    """Apply ``Preprocessor.edge_gaussian_filter`` to ``template``.

    The ``Annotated`` metadata on ``edge_algorithm`` lists the choices that the
    GUI offers in a dropdown.

    Parameters
    ----------
    template : NDArray
        Input array.
    sigma : float
        Forwarded as ``sigma``.
    edge_algorithm : str
        Forwarded as ``edge_algorithm``.
    reverse : bool
        Forwarded as ``reverse``.
    **kwargs : dict
        Accepted for interface compatibility and ignored.

    Returns
    -------
    NDArray
        The filtered array returned by the preprocessor.
    """
    options = dict(sigma=sigma, reverse=reverse, edge_algorithm=edge_algorithm)
    return preprocessor.edge_gaussian_filter(template=template, **options)
|
118
|
+
|
119
|
+
|
120
|
+
def local_gaussian_filter(
    template: NDArray,
    lbd: float,
    sigma_range: Tuple[float, float],
    gaussian_sigma: float,
    reverse: bool = False,
    **kwargs: dict,
) -> NDArray:
    """Apply ``Preprocessor.local_gaussian_filter`` to ``template``.

    Parameters
    ----------
    template : NDArray
        Input array.
    lbd : float
        Forwarded as ``lbd``.
    sigma_range : tuple of float
        Forwarded as ``sigma_range``.
    gaussian_sigma : float
        Forwarded as ``gaussian_sigma``.
    reverse : bool
        Shown in the GUI but not forwarded.
        NOTE(review): confirm whether the preprocessor supports a reverse mode.
    **kwargs : dict
        Accepted for interface compatibility and ignored.

    Returns
    -------
    NDArray
        The filtered array returned by the preprocessor.
    """
    options = {
        "lbd": lbd,
        "sigma_range": sigma_range,
        "gaussian_sigma": gaussian_sigma,
    }
    return preprocessor.local_gaussian_filter(template=template, **options)
|
134
|
+
|
135
|
+
|
136
|
+
def mean(
    template: NDArray,
    width: int,
    **kwargs: dict,
) -> NDArray:
    """Apply ``Preprocessor.mean_filter`` with window ``width`` to ``template``.

    Extra keyword arguments are accepted for interface compatibility and
    ignored.
    """
    smoothed = preprocessor.mean_filter(width=width, template=template)
    return smoothed
|
142
|
+
|
143
|
+
|
144
|
+
def wedge(
    template: NDArray,
    tilt_start: float = 40.0,
    tilt_stop: float = 40.0,
    tilt_step: float = 0,
    opening_axis: int = 2,
    tilt_axis: int = 0,
    omit_negative_frequencies: bool = False,
    infinite_plane: bool = False,
    weight_angle: bool = False,
    **kwargs,
) -> NDArray:
    """Multiply the Fourier transform of ``template`` with a wedge mask.

    Parameters
    ----------
    template : NDArray
        Real-space input; its shape determines the mask shape.
    tilt_start, tilt_stop : float
        Tilt range. For a stepped wedge the sampled angles (degrees) are
        ``np.arange(-tilt_start, tilt_stop + tilt_step, tilt_step)``.
    tilt_step : float
        Non-positive values build ``Preprocessor.continuous_wedge_mask``;
        positive values build ``Preprocessor.step_wedge_mask``.
    opening_axis, tilt_axis : int
        Forwarded to the mask constructor.
    omit_negative_frequencies : bool
        Forwarded to the mask constructor.
    infinite_plane : bool
        Only used by the continuous wedge.
    weight_angle : bool
        Stepped wedge only: weight each tilt by ``cos(angle)``.
    **kwargs
        Accepted for interface compatibility and ignored.

    Returns
    -------
    NDArray
        Real part of the inverse FFT of the masked spectrum.
    """
    if tilt_step <= 0:
        mask = preprocessor.continuous_wedge_mask(
            start_tilt=tilt_start,
            stop_tilt=tilt_stop,
            tilt_axis=tilt_axis,
            opening_axis=opening_axis,
            shape=template.shape,
            omit_negative_frequencies=omit_negative_frequencies,
            infinite_plane=infinite_plane,
        )
    else:
        angles = np.arange(-tilt_start, tilt_stop + tilt_step, tilt_step)
        angle_weights = np.cos(np.radians(angles)) if weight_angle else None
        mask = preprocessor.step_wedge_mask(
            tilt_angles=angles,
            tilt_axis=tilt_axis,
            opening_axis=opening_axis,
            shape=template.shape,
            weights=angle_weights,
            omit_negative_frequencies=omit_negative_frequencies,
        )

    spectrum = np.fft.fftn(template)
    spectrum *= mask
    return np.real(np.fft.ifftn(spectrum))
|
186
|
+
|
187
|
+
|
188
|
+
def compute_power_spectrum(template: NDArray) -> NDArray:
    """Return the centred log-magnitude Fourier spectrum of ``template``.

    Magnitudes below the smallest positive normal float (in practice, exact
    zeros such as the non-DC terms of a constant array) are clamped to that
    value before the logarithm. This keeps the result finite instead of
    producing ``-inf`` together with a divide-by-zero RuntimeWarning. All
    other entries are identical to ``log(|FFT(template)|)``.

    Parameters
    ----------
    template : NDArray
        Input array.

    Returns
    -------
    NDArray
        ``fftshift``-ed log-magnitude spectrum with the shape of ``template``.
    """
    magnitude = np.abs(np.fft.fftn(template))
    floor = np.finfo(magnitude.dtype).tiny
    return np.fft.fftshift(np.log(np.maximum(magnitude, floor)))
|
190
|
+
|
191
|
+
|
192
|
+
def invert_contrast(template: NDArray) -> NDArray:
    """Return ``template`` multiplied by -1, swapping dark and bright contrast."""
    flipped = -1 * template
    return flipped
|
194
|
+
|
195
|
+
|
196
|
+
def widgets_from_function(function: Callable, exclude_params: List = None):
    """Create magicgui widgets from the parameter annotations of ``function``.

    Parameters
    ----------
    function : Callable
        Function whose signature is inspected.
    exclude_params : list of str, optional
        Parameter names to skip. Defaults to ``["self"]``.

    Returns
    -------
    list
        One widget per supported parameter, in signature order:

        * ``float`` -> ``FloatSpinBox`` (min ``SLIDER_MIN``, step 0.5)
        * ``Tuple[float, float]`` -> ``FloatRangeSlider`` (``SLIDER_MIN``..``SLIDER_MAX``)
        * ``int`` -> ``SpinBox``
        * ``bool`` -> ``CheckBox``
        * ``Annotated[..., {"choices": [...]}]`` -> ``ComboBox``

        Parameters without an annotation, with any other annotation, or with
        ``Annotated`` metadata lacking ``"choices"`` are skipped.
    """
    if exclude_params is None:
        exclude_params = ["self"]

    empty = inspect.Parameter.empty
    ret = []
    for name, param in inspect.signature(function).parameters.items():
        if name in exclude_params:
            continue

        annotation = param.annotation
        has_default = param.default is not empty

        if annotation is float:
            widget = widgets.FloatSpinBox(
                name=name,
                value=param.default if has_default else 0,
                min=SLIDER_MIN,
                step=0.5,
            )
        elif annotation == Tuple[float, float]:
            widget = widgets.FloatRangeSlider(
                name=param.name,
                value=param.default if has_default else (0.0, SLIDER_MAX / 2),
                min=SLIDER_MIN,
                max=SLIDER_MAX,
            )
        elif annotation is int:
            widget = widgets.SpinBox(
                name=name,
                value=param.default if has_default else 0,
            )
        elif annotation is bool:
            widget = widgets.CheckBox(
                name=name,
                value=param.default if has_default else False,
            )
        elif hasattr(annotation, "__metadata__"):
            metadata = annotation.__metadata__[0]
            if not isinstance(metadata, dict) or "choices" not in metadata:
                # Unsupported metadata: skip rather than appending an unbound
                # name or the widget left over from a previous parameter.
                continue
            choices = metadata["choices"]
            widget = widgets.ComboBox(
                name=param.name,
                choices=choices,
                value=param.default if has_default else choices[0],
            )
        else:
            continue
        ret.append(widget)
    return ret
|
249
|
+
|
250
|
+
|
251
|
+
# Filters implemented in this GUI module, keyed by the function name used to
# build the filter dropdown.
WRAPPED_FUNCTIONS = dict(
    [
        ("gaussian_filter", gaussian_filter),
        ("bandpass_filter", bandpass_filter),
        ("edge_gaussian_filter", edge_gaussian_filter),
        ("local_gaussian_filter", local_gaussian_filter),
        ("difference_of_gaussian_filter", difference_of_gaussian_filter),
        ("mean_filter", mean),
        ("wedge_filter", wedge),
        ("power_spectrum", compute_power_spectrum),
        ("ctf", ctf_filter),
        ("invert_contrast", invert_contrast),
    ]
)
|
263
|
+
|
264
|
+
# Preprocessor methods that are not offered as filters in the GUI.
# Each name is listed once.
EXCLUDED_FUNCTIONS = [
    "apply_method",
    "method_to_id",
    "wedge_mask",
    "fourier_crop",
    "fourier_uncrop",
    "interpolate_box",
    "molmap",
    "local_gaussian_alignment_filter",
    "continuous_wedge_mask",
    "bandpass_mask",
]
|
277
|
+
|
278
|
+
|
279
|
+
class FilterWidget(widgets.Container):
    """Napari container that applies preprocessing filters to image layers.

    The user selects a target layer and a filter. Parameter widgets are
    generated from the filter's signature via ``widgets_from_function``.
    Applying a filter writes the result to the layer
    ``"<layer> (<filter label>)"``: that layer is updated in place when it was
    produced by the same filter; otherwise a new layer is added.
    """

    def __init__(self, preprocessor, viewer):
        super().__init__(layout="vertical")

        self.preprocessor = preprocessor
        self.viewer = viewer
        # Dropdown label -> function name (preprocessor method or WRAPPED_FUNCTIONS key).
        self.name_mapping = {}
        # Parameter widgets of the currently selected filter.
        self.action_widgets = []

        self.layer_dropdown = widgets.ComboBox(
            name="Target Layer", choices=self._get_layer_names()
        )
        self.append(self.layer_dropdown)
        self.viewer.layers.events.inserted.connect(self._update_layer_dropdown)
        self.viewer.layers.events.removed.connect(self._update_layer_dropdown)

        self.method_dropdown = widgets.ComboBox(
            name="Choose Filter", choices=self._get_method_names()
        )
        self.method_dropdown.changed.connect(self._on_method_changed)
        self.append(self.method_dropdown)

        # Start enabled when the viewer already holds layers; later changes are
        # tracked by _update_layer_dropdown.
        self.apply_btn = widgets.PushButton(
            text="Apply Filter", enabled=bool(self.viewer.layers)
        )
        self.apply_btn.changed.connect(self._action)
        self.append(self.apply_btn)

        # Create GUI for initially selected filtering method
        self._on_method_changed(None)

    def _get_method_names(self):
        """Return sorted dropdown labels and record them in ``self.name_mapping``.

        Candidates are the public preprocessor methods not listed in
        ``EXCLUDED_FUNCTIONS`` plus all keys of ``WRAPPED_FUNCTIONS``.
        """
        method_names = [
            name
            for name, member in inspect.getmembers(self.preprocessor, inspect.ismethod)
            if not name.startswith("_") and name not in EXCLUDED_FUNCTIONS
        ]
        method_names.extend(list(WRAPPED_FUNCTIONS.keys()))
        method_names = list(set(method_names))

        sanitized_names = [self._sanitize_name(name) for name in method_names]
        self.name_mapping.update(dict(zip(sanitized_names, method_names)))
        sanitized_names.sort()

        return sanitized_names

    def _sanitize_name(self, name: str) -> str:
        """Turn a function name into a dropdown label.

        Removes the substrings "blur" and "filter", replaces underscores with
        spaces and title-cases the result.
        """
        removes = ["blur", "filter"]
        for remove in removes:
            name = name.replace(remove, "")
        return name.strip().replace("_", " ").title()

    def _desanitize_name(self, name: str) -> str:
        """Map a dropdown label to the first preprocessor method starting with it.

        The label is lower-cased. If no method name starts with it, the
        lower-cased label itself is returned.
        NOTE(review): multi-word labels contain spaces and therefore never
        match method names that use underscores.
        """
        name = name.lower().strip()
        for function_name, _ in inspect.getmembers(self.preprocessor, inspect.ismethod):
            if function_name.startswith(name):
                return function_name
        return name

    def _get_function(self, name: str):
        """Return the wrapped GUI function ``name`` or the preprocessor method."""
        function = WRAPPED_FUNCTIONS.get(name, None)
        if not function:
            function = getattr(self.preprocessor, name, None)
        return function

    def _on_method_changed(self, event=None):
        """Rebuild the parameter widgets for the currently selected filter."""
        for widget in self.action_widgets:
            self.remove(widget)
        self.action_widgets.clear()

        function_name = self.name_mapping.get(self.method_dropdown.value)
        function = self._get_function(function_name)
        if function is None:
            return None

        parameter_widgets = widgets_from_function(
            function, exclude_params=["self", "template"]
        )
        for widget in parameter_widgets:
            self.action_widgets.append(widget)
            # Insert before the final element (the apply button).
            self.insert(-1, widget)

    def _update_layer_dropdown(self, event: EventedList):
        """Update the dropdown menu when layers change."""
        self.layer_dropdown.choices = self._get_layer_names()
        self.apply_btn.enabled = bool(self.viewer.layers)

    def _get_layer_names(self):
        """Return list of layer names in the viewer."""
        return sorted([layer.name for layer in self.viewer.layers])

    def _action(self, event):
        """Apply the selected filter to the selected layer and store the result."""
        layer_name = self.layer_dropdown.value
        if layer_name is None or layer_name not in self.viewer.layers:
            return None

        selected_layer = self.viewer.layers[layer_name]
        selected_layer_metadata = selected_layer.metadata.copy()
        kwargs = {widget.name: widget.value for widget in self.action_widgets}

        function_name = self.name_mapping.get(self.method_dropdown.value)
        function = self._get_function(function_name)

        if "sampling_rate" in inspect.getfullargspec(function).args:
            kwargs["sampling_rate"] = selected_layer_metadata["sampling_rate"]

        processed_data = function(selected_layer.data, **kwargs)

        new_layer_name = f"{selected_layer.name} ({self.method_dropdown.value})"
        if new_layer_name in self.viewer.layers:
            selected_layer = self.viewer.layers[new_layer_name]

        filter_name = self._desanitize_name(self.method_dropdown.value)
        parameters = {filter_name: kwargs.copy()}

        used_filter = selected_layer.metadata.get("used_filter", False)
        if used_filter == filter_name:
            selected_layer.data = processed_data
            # Keep the recorded parameters in sync with the data now displayed.
            history = list(selected_layer.metadata.get("filter_parameters", []))
            if history and filter_name in history[-1]:
                history[-1] = parameters
            else:
                history.append(parameters)
            selected_layer.metadata["filter_parameters"] = history
            return None

        new_layer = self.viewer.add_image(
            data=processed_data,
            name=new_layer_name,
        )
        metadata = selected_layer_metadata.copy()
        # Build a fresh list: metadata is a shallow copy, so appending to the
        # existing list would also modify the source layer's metadata.
        metadata["filter_parameters"] = [
            *metadata.get("filter_parameters", []),
            parameters,
        ]
        metadata["used_filter"] = filter_name
        new_layer.metadata = metadata
|
399
|
+
|
400
|
+
|
401
|
+
def sphere_mask(
    template: NDArray,
    center_x: float,
    center_y: float,
    center_z: float,
    radius: float,
    sigma_decay: float = 0,
    **kwargs,
) -> NDArray:
    """Create a spherical mask with the shape of ``template``.

    Implemented as ``create_mask(mask_type="ellipse", ...)`` with a single
    scalar ``radius``. Only ``template.shape`` is used; extra keyword
    arguments are ignored.
    """
    center = (center_x, center_y, center_z)
    options = dict(center=center, radius=radius, sigma_decay=sigma_decay)
    return create_mask(mask_type="ellipse", shape=template.shape, **options)
|
417
|
+
|
418
|
+
|
419
|
+
def ellipsod_mask(
    template: NDArray,
    center_x: float,
    center_y: float,
    center_z: float,
    radius_x: float,
    radius_y: float,
    radius_z: float,
    sigma_decay: float = 0,
    **kwargs,
) -> NDArray:
    """Create an ellipsoidal mask with the shape of ``template``.

    The per-axis radii are passed to ``create_mask(mask_type="ellipse", ...)``.
    Only ``template.shape`` is used; extra keyword arguments are ignored.
    """
    center = (center_x, center_y, center_z)
    radii = (radius_x, radius_y, radius_z)
    return create_mask(
        mask_type="ellipse",
        shape=template.shape,
        center=center,
        radius=radii,
        sigma_decay=sigma_decay,
    )
|
437
|
+
|
438
|
+
|
439
|
+
def box_mask(
    template: NDArray,
    center_x: float,
    center_y: float,
    center_z: float,
    height_x: int,
    height_y: int,
    height_z: int,
    sigma_decay: float = 0,
    **kwargs,
) -> NDArray:
    """Create a box mask with the shape of ``template``.

    Center and per-axis heights are passed to
    ``create_mask(mask_type="box", ...)``. Only ``template.shape`` is used;
    extra keyword arguments are ignored.
    """
    center = (center_x, center_y, center_z)
    heights = (height_x, height_y, height_z)
    return create_mask(
        mask_type="box",
        shape=template.shape,
        center=center,
        height=heights,
        sigma_decay=sigma_decay,
    )
|
457
|
+
|
458
|
+
|
459
|
+
def tube_mask(
    template: NDArray,
    symmetry_axis: int,
    center_x: float,
    center_y: float,
    center_z: float,
    inner_radius: float,
    outer_radius: float,
    height: int,
    sigma_decay: float = 0,
    **kwargs,
) -> NDArray:
    """Create a tube mask with the shape of ``template``.

    The center coordinates are passed to ``create_mask(mask_type="tube", ...)``
    as ``base_center``. Only ``template.shape`` is used; extra keyword
    arguments are ignored.
    """
    base_center = (center_x, center_y, center_z)
    tube_options = {
        "symmetry_axis": symmetry_axis,
        "base_center": base_center,
        "inner_radius": inner_radius,
        "outer_radius": outer_radius,
        "height": height,
        "sigma_decay": sigma_decay,
    }
    return create_mask(mask_type="tube", shape=template.shape, **tube_options)
|
481
|
+
|
482
|
+
|
483
|
+
def wedge_mask(
    template: NDArray,
    tilt_start: float = 40.0,
    tilt_stop: float = 40.0,
    tilt_step: float = 0,
    opening_axis: int = 2,
    tilt_axis: int = 0,
    omit_negative_frequencies: bool = False,
    infinite_plane: bool = False,
    weight_angle: bool = False,
    **kwargs,
) -> NDArray:
    """Return a wedge mask for ``template.shape``, passed through ``np.fft.fftshift``.

    A non-positive ``tilt_step`` builds ``Preprocessor.continuous_wedge_mask``
    (``infinite_plane`` only applies here). A positive step builds
    ``Preprocessor.step_wedge_mask`` over the angles (degrees)
    ``np.arange(-tilt_start, tilt_stop + tilt_step, tilt_step)``. With
    ``weight_angle``, each angle is weighted by its cosine. Extra keyword
    arguments are ignored.
    """
    if tilt_step <= 0:
        mask = preprocessor.continuous_wedge_mask(
            start_tilt=tilt_start,
            stop_tilt=tilt_stop,
            tilt_axis=tilt_axis,
            opening_axis=opening_axis,
            shape=template.shape,
            omit_negative_frequencies=omit_negative_frequencies,
            infinite_plane=infinite_plane,
        )
    else:
        angles = np.arange(-tilt_start, tilt_stop + tilt_step, tilt_step)
        angle_weights = np.cos(np.radians(angles)) if weight_angle else None
        mask = preprocessor.step_wedge_mask(
            tilt_angles=angles,
            tilt_axis=tilt_axis,
            opening_axis=opening_axis,
            shape=template.shape,
            weights=angle_weights,
            omit_negative_frequencies=omit_negative_frequencies,
        )
    return np.fft.fftshift(mask)
|
524
|
+
|
525
|
+
|
526
|
+
def threshold_mask(
    template: NDArray,
    invert: bool = False,
    standard_deviation: float = 5.0,
    sigma: float = 0.0,
    **kwargs,
) -> NDArray:
    """Select voxels deviating strongly from the template mean.

    A voxel is selected when it is at or beyond ``standard_deviation``
    standard deviations below or above ``template.mean()``.

    Parameters
    ----------
    template : NDArray
        Input array.
    invert : bool
        Return the complement of the mask: logical NOT for the boolean mask,
        ``1 - mask`` for the soft (float) mask produced when ``sigma != 0``.
    standard_deviation : float
        Multiple of ``template.std()`` defining the selection band.
    sigma : float
        If non-zero, the boolean mask is extended by its Gaussian-filtered
        version (``Preprocessor.gaussian_filter``). Values below
        ``exp(-sigma**2)`` are then zeroed, giving a float-valued mask.
    **kwargs
        Accepted for interface compatibility and ignored.

    Returns
    -------
    NDArray
        Boolean mask when ``sigma == 0``, otherwise a float mask.
    """
    template_mean = template.mean()
    template_deviation = standard_deviation * template.std()
    upper = template_mean + template_deviation
    lower = template_mean - template_deviation
    mask = np.logical_or(template <= lower, template >= upper)

    if sigma != 0:
        mask_filter = preprocessor.gaussian_filter(template=mask * 1.0, sigma=sigma)
        mask = np.add(mask, (1 - mask) * mask_filter)
        mask[mask < np.exp(-np.square(sigma))] = 0

    if invert:
        if mask.dtype == np.bool_:
            np.invert(mask, out=mask)
        else:
            # np.invert is undefined for floats; complement the soft mask
            # arithmetically instead of raising a TypeError.
            mask = 1 - mask

    return mask
|
548
|
+
|
549
|
+
|
550
|
+
def lowpass_mask(template: NDArray, sigma: float = 1.0, **kwargs):
    """Create a boolean mask from a thresholded, Gaussian-smoothed template.

    The template is scaled by its maximum. Values above ``exp(-2)`` are set to
    128 (all others to 0), smoothed with ``Preprocessor.gaussian_filter`` of
    width ``sigma``, and thresholded again at ``exp(-2)``. Extra keyword
    arguments are ignored.
    """
    cutoff = np.exp(-2)
    normalized = template / template.max()
    seed = (normalized > cutoff) * 128.0
    smoothed = preprocessor.gaussian_filter(template=seed, sigma=sigma)
    return smoothed > cutoff
|
557
|
+
|
558
|
+
|
559
|
+
def shape_mask(template, shapes_layer, expansion_dim):
    """Extrude the 2D shapes of ``shapes_layer`` into a mask shaped like ``template``.

    Each shape is rasterised with ``shapes_layer.to_masks`` on the plane that
    excludes ``expansion_dim``. The raster is repeated along ``expansion_dim``
    and OR-ed into the output.

    Returns
    -------
    NDArray
        Array with the shape and dtype of ``template``; covered voxels are 1.
    """
    combined = np.zeros_like(template)
    plane_shape = tuple(
        size for axis, size in enumerate(template.shape) if axis != expansion_dim
    )
    planes = shapes_layer.to_masks(mask_shape=plane_shape)

    for index, _ in enumerate(shapes_layer.shape_type):
        extruded = np.repeat(
            np.expand_dims(planes[index], axis=expansion_dim),
            repeats=template.shape[expansion_dim],
            axis=expansion_dim,
        )
        np.logical_or(combined, extruded, out=combined)

    return combined
|
571
|
+
|
572
|
+
|
573
|
+
class MaskWidget(widgets.Container):
    """Napari container that creates masks for the active layer.

    The user picks a mask type. Parameter widgets are built from the mask
    function's signature, and "Adapt to layer" fills them from the bounding
    box of the active layer's data above a quantile cutoff. The mask is written
    to ``"<layer> (<mask type>)"``: that layer is updated in place when it
    holds the same mask type; otherwise a new layer is added.
    """

    def __init__(self, viewer):
        super().__init__(layout="vertical")

        self.viewer = viewer
        # Parameter widgets of the currently selected mask type.
        self.action_widgets = []

        self.action_button = widgets.PushButton(text="Create mask", enabled=False)
        self.action_button.changed.connect(self._action)

        self.methods = {
            "Sphere": sphere_mask,
            "Ellipsoid": ellipsod_mask,
            "Tube": tube_mask,
            "Box": box_mask,
            "Wedge": wedge_mask,
            "Threshold": threshold_mask,
            "Lowpass": lowpass_mask,
            "Shape": shape_mask,
        }

        self.method_dropdown = widgets.ComboBox(
            name="Choose Mask", choices=list(self.methods.keys())
        )
        self.method_dropdown.changed.connect(self._on_method_changed)

        self.percentile_range_edit = widgets.FloatSpinBox(
            name="Data Quantile", min=0, max=100, value=0, step=2
        )

        self.adapt_button = widgets.PushButton(text="Adapt to layer", enabled=False)
        self.adapt_button.changed.connect(self._update_initial_values)
        self.viewer.layers.selection.events.active.connect(
            self._update_action_button_state
        )

        # Label reserved for mask statistics; currently not filled.
        self.density_field = widgets.Label()

        self.shapes_layer_dropdown = widgets.ComboBox(
            name="shapes_layer", choices=self._get_shape_layers()
        )
        self.viewer.layers.events.inserted.connect(self._update_shape_layer_choices)
        self.viewer.layers.events.removed.connect(self._update_shape_layer_choices)

        self.append(self.method_dropdown)
        self.append(self.adapt_button)
        self.append(self.percentile_range_edit)

        self.append(self.action_button)
        self.append(self.density_field)

        # Create GUI for initially selected filtering method
        self._on_method_changed(None)

    def _update_action_button_state(self, event):
        """Enable the buttons only while a layer is active."""
        self.action_button.enabled = bool(self.viewer.layers.selection.active)
        self.adapt_button.enabled = bool(self.viewer.layers.selection.active)

    def _update_initial_values(self, event=None):
        """Fill center/radius/height widgets from the active layer's bounding box.

        Values below the selected data quantile (and at least float32
        resolution) are ignored. Nothing is changed when no layer is active or
        no value survives the cutoff.
        """
        active_layer = self.viewer.layers.selection.active
        if active_layer is None:
            return None

        data = active_layer.data.copy()
        cutoff = np.quantile(data, self.percentile_range_edit.value / 100)
        cutoff = max(cutoff, np.finfo(np.float32).resolution)
        data[data < cutoff] = 0

        coordinates = np.array(np.where(data > 0))
        if coordinates.shape[1] == 0:
            # Empty selection: min/max below would raise on a zero-size array.
            return None

        coordinates_min = coordinates.min(axis=1)
        coordinates_max = coordinates.max(axis=1)
        coordinates_heights = coordinates_max - coordinates_min
        coordinate_radius = np.divide(coordinates_heights, 2)
        # Center of the bounding box of the retained values.
        center_of_mass = coordinate_radius + coordinates_min

        defaults = dict(zip(["center_x", "center_y", "center_z"], center_of_mass))
        defaults.update(
            dict(zip(["radius_x", "radius_y", "radius_z"], coordinate_radius))
        )
        defaults.update(
            dict(zip(["height_x", "height_y", "height_z"], coordinates_heights))
        )

        defaults["radius"] = np.max(coordinate_radius)
        defaults["inner_radius"] = np.min(coordinate_radius)
        defaults["outer_radius"] = np.max(coordinate_radius)
        defaults["height"] = np.max(coordinates_heights)

        for widget in self.action_widgets:
            if widget.name in defaults:
                widget.value = defaults[widget.name]

    def _on_method_changed(self, event=None):
        """Rebuild the parameter widgets for the currently selected mask type."""
        for widget in self.action_widgets:
            self.remove(widget)
        self.action_widgets.clear()

        function = self.methods.get(self.method_dropdown.value)
        function_widgets = widgets_from_function(function)
        for widget in function_widgets:
            self.action_widgets.append(widget)
            self.insert(1, widget)

        # Shape masks take a layer object, which widgets_from_function cannot
        # build, so the shared shapes-layer dropdown is inserted instead.
        if "shapes_layer" in inspect.signature(function).parameters:
            self.action_widgets.append(self.shapes_layer_dropdown)
            self.insert(1, self.shapes_layer_dropdown)

    def _get_shape_layers(self):
        """Return the names of all Shapes layers in the viewer."""
        layers = [
            layer.name
            for layer in self.viewer.layers
            if isinstance(layer, napari.layers.Shapes)
        ]
        return layers

    def _update_shape_layer_choices(self, event):
        """Update the choices in the shapes layer dropdown."""
        self.shapes_layer_dropdown.choices = self._get_shape_layers()

    def _action(self):
        """Create (or update) the mask layer for the active layer."""
        function = self.methods.get(self.method_dropdown.value)

        selected_layer = self.viewer.layers.selection.active
        if selected_layer is None:
            return None
        kwargs = {widget.name: widget.value for widget in self.action_widgets}

        if "shapes_layer" in kwargs:
            layer_name = kwargs["shapes_layer"]
            if layer_name not in self.viewer.layers:
                return None
            kwargs["shapes_layer"] = self.viewer.layers[layer_name]
            kwargs["expansion_dim"] = self.viewer.dims.order[0]

        processed_data = function(template=selected_layer.data, **kwargs)

        new_layer_name = f"{selected_layer.name} ({self.method_dropdown.value})"

        if new_layer_name in self.viewer.layers:
            selected_layer = self.viewer.layers[new_layer_name]

        processed_data = processed_data.astype(np.float32)
        mask = selected_layer.metadata.get("mask", False)
        if mask == self.method_dropdown.value:
            selected_layer.data = processed_data
            return None

        new_layer = self.viewer.add_image(
            data=processed_data,
            name=new_layer_name,
        )
        if self.method_dropdown.value == "Shape":
            # No metadata is stored for shape masks (presumably because their
            # kwargs hold a napari layer object).
            new_layer.metadata = {}
            return None

        metadata = selected_layer.metadata.copy()
        metadata["filter_parameters"] = {self.method_dropdown.value: kwargs.copy()}
        metadata["mask"] = self.method_dropdown.value
        metadata["origin_layer"] = selected_layer.name
        new_layer.metadata = metadata
|
740
|
+
|
741
|
+
|
742
|
+
class AlignmentWidget(widgets.Container):
    """Dock widget with tools to re-orient the active layer's data.

    "Align to axis" rotates the data so that the principal axis of its
    positive voxels points along the array axis chosen in the combo box.
    "Rotate 90" / "Rotate -90" apply ``np.rot90`` to the data.
    """

    def __init__(self, viewer):
        super().__init__(layout="vertical")

        self.viewer = viewer

        align_button = widgets.PushButton(text="Align to axis", enabled=True)
        self.align_axis = widgets.ComboBox(
            value=None, nullable=True, choices=self._get_active_layer_dims
        )
        self.viewer.layers.selection.events.changed.connect(self._update_align_axis)

        align_button.changed.connect(self._align_with_axis)
        container = widgets.Container(
            widgets=[align_button, self.align_axis], layout="horizontal"
        )
        self.append(container)

        rot90 = widgets.PushButton(text="Rotate 90", enabled=True)
        rotneg90 = widgets.PushButton(text="Rotate -90", enabled=True)

        # Connect through zero-argument wrappers so the value emitted by the
        # button signal is never forwarded into ``_rot90``'s ``swap_axes``.
        rot90.changed.connect(self._rotpos90)
        rotneg90.changed.connect(self._rotneg90)

        container = widgets.Container(widgets=[rot90, rotneg90], layout="horizontal")
        self.append(container)

    def _rot90(self, swap_axes: bool = False):
        """Rotate the active layer's data by 90 degrees with ``np.rot90``.

        Only acts when a layer is active and the viewer is in 2D display mode.
        The rotation plane is spanned by the selected alignment axis (or
        ``viewer.dims.order[0]`` if none is selected) and the lowest-numbered
        remaining axis.

        Parameters
        ----------
        swap_axes : bool, optional
            Reverse the order of the rotation-plane axes, i.e. rotate by -90
            degrees instead of +90.
        """
        active_layer = self.viewer.layers.selection.active
        if active_layer is None:
            return None
        elif self.viewer.dims.ndisplay != 2:
            return None

        align_axis = self.align_axis.value
        if align_axis is None:
            align_axis = self.viewer.dims.order[0]

        axes = [
            align_axis,
            *[i for i in range(len(self.viewer.dims.order)) if i != align_axis],
        ][:2]
        axes = axes[::-1] if swap_axes else axes
        active_layer.data = np.rot90(active_layer.data, k=1, axes=axes)

        # The data was re-oriented, so any earlier principal-axis alignment no
        # longer holds; drop the flag so "Align to axis" can run again.
        active_layer.metadata.pop("is_aligned", None)

    def _rotpos90(self, *args):
        """Rotate the active layer by +90 degrees (button callback)."""
        return self._rot90(swap_axes=False)

    def _rotneg90(self):
        """Rotate the active layer by -90 degrees (button callback)."""
        return self._rot90(swap_axes=True)

    def _get_active_layer_dims(self, *args):
        """Return the axis indices of the active layer's data, or ``()``."""
        active_layer = self.viewer.layers.selection.active
        if active_layer is None:
            return ()
        try:
            return [i for i in range(active_layer.data.ndim)]
        except Exception:
            # Layers whose data has no ``ndim`` offer no axis choices.
            return ()

    def _update_align_axis(self, *args):
        """Refresh the axis choices when the layer selection changes."""
        self.align_axis.choices = self._get_active_layer_dims()

    def _align_with_axis(self):
        """Rotate the active layer so its principal axis lies along the chosen axis.

        The principal axis is the eigenvector with the largest eigenvalue of
        the covariance of the coordinates of all voxels with value > 0. The
        rotated data has values below machine epsilon set to zero, and the
        chosen axis is recorded in ``metadata["is_aligned"]`` so that repeated
        clicks for the same axis are no-ops.
        """
        active_layer = self.viewer.layers.selection.active
        if active_layer is None:
            return None

        if self.align_axis.value is None:
            return None

        if active_layer.metadata.get("is_aligned", None) == self.align_axis.value:
            return None

        alignment_axis = np.zeros(active_layer.data.ndim)
        alignment_axis[int(self.align_axis.value)] = 1

        coords = np.array(np.where(active_layer.data > 0)).T
        # A covariance estimate needs at least two samples; with fewer there
        # is no principal axis to align.
        if coords.shape[0] < 2:
            return None

        centered_coords = coords - np.mean(coords, axis=0)
        cov_matrix = np.cov(centered_coords, rowvar=False)
        eigenvalues, eigenvectors = np.linalg.eigh(cov_matrix)
        principal_eigenvector = eigenvectors[:, np.argmax(eigenvalues)]

        rotation_matrix = align_vectors(principal_eigenvector, alignment_axis)
        rotated_data, _ = backend.rigid_transform(
            arr=active_layer.data,
            rotation_matrix=rotation_matrix,
            use_geometric_center=False,
            order=1,
        )
        # Suppress interpolation noise introduced by the transform.
        eps = np.finfo(rotated_data.dtype).eps
        rotated_data[rotated_data < eps] = 0

        active_layer.metadata["is_aligned"] = int(self.align_axis.value)
        active_layer.data = rotated_data
|
832
|
+
|
833
|
+
|
834
|
+
class ExportWidget(widgets.Container):
    """Dock widget that saves the active image layer to an MRC file.

    A "gzip" checkbox is stored in the layer's ``metadata["write_gzip"]``
    before saving.
    """

    def __init__(self, viewer):
        super().__init__(layout="vertical")

        self.viewer = viewer
        self.selected_filename = ""

        horizontal_container = widgets.Container(layout="horizontal")

        self.gzip_output = widgets.CheckBox(name="gzip", value=False, label="gzip")
        self.export_button = widgets.PushButton(name="Export", text="Export")
        self.export_button.clicked.connect(self._get_save_path)

        horizontal_container.append(self.export_button)
        horizontal_container.append(self.gzip_output)

        self.append(horizontal_container)

        self._update_export_button_state()
        self.viewer.layers.selection.events.active.connect(
            self._update_export_button_state
        )

    def _get_save_path(self, event=None):
        """Ask the user for a target path and export the active layer there."""
        path, _ = QFileDialog.getSaveFileName(
            self.native,
            "Save As...",
            "",
            "MRC Files (*.mrc)",
        )
        if path:
            self.selected_filename = path
            self._export_data()

    def _update_export_button_state(self, event=None):
        """Update the enabled state of the export button based on the active layer.

        Only image layers can be exported by ``_export_data``, so the button
        is enabled only when the active layer is an ``Image``.
        """
        active_layer = self.viewer.layers.selection.active
        self.export_button.enabled = isinstance(active_layer, Image)

    def _export_data(self):
        """Save the active image layer to ``self.selected_filename``."""
        selected_layer = self.viewer.layers.selection.active
        if selected_layer and isinstance(selected_layer, Image):
            selected_layer.metadata["write_gzip"] = self.gzip_output.value
            selected_layer.save(path=self.selected_filename)
|
877
|
+
|
878
|
+
|
879
|
+
class PointCloudWidget(widgets.Container):
    """Dock widget to import, filter, label and export point clouds.

    Point clouds are napari ``Points`` layers carrying per-point ``score`` and
    ``detail`` properties. ``detail`` holds the label assigned with the
    Positive (1) / Negative (0) buttons; -1 is used when no label is selected.
    The table a layer was loaded from is kept in ``self.dataframes`` under the
    layer's name so that rotations can be restored on export.
    """

    def __init__(self, viewer):
        super().__init__(layout="vertical")

        self.viewer = viewer
        self.dataframes = {}
        self.selected_category = -1

        self.import_button = widgets.PushButton(
            name="Import", text="Import Point Cloud"
        )
        self.import_button.clicked.connect(self._get_load_path)

        self.export_button = widgets.PushButton(
            name="Export", text="Export Point Cloud"
        )
        self.export_button.clicked.connect(self._export_point_cloud)
        self.export_button.enabled = False

        self.score_filter_container = widgets.Container(
            name="Score Filter", layout="vertical"
        )
        self.score_range_slider = widgets.FloatRangeSlider(
            min=0,
            max=100,
            step=0.1,
            value=(0, 100),
            readout=False,
        )
        self.score_range_slider.changed.connect(self._update_point_visibility)
        self.score_filter_container.append(self.score_range_slider)

        self.annotation_container = widgets.Container(name="Label", layout="horizontal")
        self.positive_button = widgets.PushButton(name="Positive", text="Positive")
        self.negative_button = widgets.PushButton(name="Negative", text="Negative")
        self.positive_button.clicked.connect(self._set_positive)
        self.negative_button.clicked.connect(self._set_negative)
        self.annotation_container.append(self.positive_button)
        self.annotation_container.append(self.negative_button)

        self.face_color_select = widgets.ComboBox(
            name="Color", choices=["Label", "Score"], value=None, nullable=True
        )
        self.face_color_select.changed.connect(self._update_face_color_mode)

        self.append(self.import_button)
        self.append(self.export_button)
        self.append(self.score_filter_container)
        self.append(self.annotation_container)
        self.append(self.face_color_select)

        self.viewer.layers.selection.events.changed.connect(self._update_buttons)

        self.viewer.layers.events.inserted.connect(self._initialize_points_layer)

    def _update_point_visibility(self, event=None):
        """Show only points whose score lies inside the slider's percentile range."""
        min_percentile, max_percentile = self.score_range_slider.value

        for layer in self.viewer.layers:
            if not isinstance(layer, napari.layers.Points):
                continue

            if "score" not in layer.properties:
                continue

            scores = layer.properties["score"]
            layer.selected_data = set()

            # np.percentile raises on an empty array.
            if len(scores) == 0:
                continue

            min_score = np.percentile(scores, min_percentile)
            max_score = np.percentile(scores, max_percentile)
            visible_mask = (scores >= min_score) & (scores <= max_score)
            layer.shown = visible_mask

    def _update_face_color_mode(self, event: str = None):
        """Color all Points layers by "Label" (detail), "Score", or plain white."""
        for layer in self.viewer.layers:
            if not isinstance(layer, napari.layers.Points):
                continue

            layer.face_color = "white"
            if event == "Label":
                if len(layer.properties.get("detail", ())) == 0:
                    continue
                layer.face_color = "detail"
                layer.face_color_cycle = {
                    -1: "grey",
                    0: "red",
                    1: "green",
                }
                layer.face_color_mode = "cycle"
            elif event == "Score":
                if len(layer.properties.get("score_scaled", ())) == 0:
                    continue
                layer.face_color = "score_scaled"
                layer.face_colormap = "turbo"
                layer.face_color_mode = "colormap"

            layer.refresh_colors()

        return None

    def _set_positive(self, event):
        """Toggle the positive (1) label on/off."""
        self.selected_category = 1 if self.selected_category != 1 else -1
        self._update_annotation_buttons()

    def _set_negative(self, event):
        """Toggle the negative (0) label on/off."""
        self.selected_category = 0 if self.selected_category != 0 else -1
        self._update_annotation_buttons()

    def _update_annotation_buttons(self):
        """Highlight the button of the currently selected label."""
        selected_style = "background-color: darkgrey"
        default_style = "background-color: none"

        self.positive_button.native.setStyleSheet(
            selected_style if self.selected_category == 1 else default_style
        )
        self.negative_button.native.setStyleSheet(
            selected_style if self.selected_category == 0 else default_style
        )

    def _initialize_points_layer(self, event):
        """Prepare a newly inserted Points layer for labelling.

        Ensures a ``detail`` property exists (one -1 per point, and -1 as the
        default for newly added points) and registers ``_on_click``.
        """
        layer = event.value
        if not isinstance(layer, napari.layers.Points):
            return

        if "detail" not in layer.properties:
            # Points layers are not subscriptable; properties must be replaced
            # as a whole, with one value per existing point.
            properties = dict(layer.properties)
            properties["detail"] = np.full(len(layer.data), -1)
            layer.properties = properties

            current = dict(layer.current_properties)
            current["detail"] = np.array([-1])
            layer.current_properties = current

        layer.mouse_drag_callbacks.append(self._on_click)
        return None

    def _on_click(self, layer, event):
        """Apply the selected label to new (add mode) or selected (select mode) points."""
        if "detail" not in layer.properties:
            return None

        if layer.mode == "add":
            # current_properties may return a copy, so the new default label
            # is assigned back rather than mutated in place.
            current = dict(layer.current_properties)
            current["detail"] = np.array([self.selected_category])
            layer.current_properties = current
        elif layer.mode == "select":
            indices = list(layer.selected_data)
            if indices:
                details = np.array(layer.properties["detail"], copy=True)
                details[indices] = self.selected_category
                properties = dict(layer.properties)
                properties["detail"] = details
                layer.properties = properties

        # TODO: Check whether current face color is the desired one already
        self._update_face_color_mode(self.face_color_select.value)
        layer.refresh_colors()

    def _update_buttons(self, event):
        """Enable export only when the active layer is a Points layer."""
        active_layer = self.viewer.layers.selection.active
        self.export_button.enabled = isinstance(active_layer, napari.layers.Points)

    def _export_point_cloud(self, event):
        """Write the visible points of the active Points layer to an orientations file.

        Coordinates are joined back onto the table the layer was loaded from
        to recover rotations and scores. Unmatched points get zero Euler
        angles and a score of 1. Labels come from the layer's ``detail``
        property, and points hidden by the score filter are dropped.
        """
        filename, _ = QFileDialog.getSaveFileName(
            self.native,
            "Save Point Cloud File...",
            "",
            "Orientations (*.tsv *.star);; All Files (*)",
        )

        if not filename:
            return None

        layer = self.viewer.layers.selection.active
        if not isinstance(layer, napari.layers.Points):
            return None

        original_dataframe = self.dataframes.get(
            layer.name,
            pd.DataFrame(columns=["x", "y", "z", "euler_x", "euler_y", "euler_z"]),
        )
        export_data = pd.DataFrame(layer.data, columns=["x", "y", "z"])
        merged_data = pd.merge(
            export_data, original_dataframe, on=["x", "y", "z"], how="left"
        )

        for col in ("z", "y", "x"):
            merged_data[col] = merged_data[col].astype(int)

        # Cast before filling: the empty fallback table yields object columns.
        for col in ("euler_z", "euler_y", "euler_x"):
            if col not in merged_data.columns:
                merged_data[col] = 0.0
            merged_data[col] = merged_data[col].astype(float).fillna(0)

        # Layers not loaded from a file (drawn by hand, matching candidates)
        # have no score column in the source table.
        if "score" not in merged_data.columns:
            merged_data["score"] = 1
        merged_data["score"] = merged_data["score"].fillna(1)

        details = layer.properties.get("detail")
        if details is None:
            details = np.full(len(layer.data), -1)
        merged_data["detail"] = details
        merged_data = merged_data[layer.shown]

        translations = np.stack(
            (merged_data["x"], merged_data["y"], merged_data["z"]), axis=1
        )
        rotations = np.stack(
            (
                merged_data["euler_x"],
                merged_data["euler_y"],
                merged_data["euler_z"],
            ),
            axis=1,
        )
        orientations = Orientations(
            translations=translations,
            rotations=rotations,
            scores=merged_data["score"],
            details=merged_data["detail"],
        )
        orientations.to_file(filename)

    def _get_load_path(self, event):
        """Ask the user for an orientations file and load it."""
        filename, _ = QFileDialog.getOpenFileName(
            self.native,
            "Open Point Cloud File...",
            "",
            "Orientations (*.tsv *.star);; All Files (*)",
        )
        if filename:
            self._load_point_cloud(filename)

    def _load_point_cloud(self, filename):
        """Load an orientations file and add it to the viewer as a Points layer.

        The full table (coordinates, Euler angles, score, detail) is stored in
        ``self.dataframes`` under the created layer's name.
        """
        orientations = Orientations.from_file(filename)

        data = {
            "x": orientations.translations[:, 0],
            "y": orientations.translations[:, 1],
            "z": orientations.translations[:, 2],
            "euler_x": orientations.rotations[:, 0],
            "euler_y": orientations.rotations[:, 1],
            "euler_z": orientations.rotations[:, 2],
            "score": orientations.scores,
            "detail": orientations.details,
        }

        dataframe = pd.DataFrame.from_dict(data)
        layer_name = basename(filename)

        point_properties = {
            "score": np.array(dataframe["score"].values),
            "detail": np.array(dataframe["detail"].values),
        }
        # Shifted log so colormaps are not dominated by a few high scores.
        point_properties["score_scaled"] = np.log1p(
            point_properties["score"] - point_properties["score"].min()
        )
        self.score_range_slider.value = (0, 100)
        layer = self.viewer.add_points(
            dataframe[["x", "y", "z"]].values,
            size=10,
            properties=point_properties,
            name=layer_name,
        )
        # napari may rename the layer (e.g. "name [1]") when the name is taken;
        # key by the actual name so _export_point_cloud finds the table.
        self.dataframes[layer.name] = dataframe
|
1135
|
+
|
1136
|
+
|
1137
|
+
class MatchingWidget(widgets.Container):
    """Dock widget that loads template matching results from a pickle file."""

    def __init__(self, viewer):
        super().__init__(layout="vertical")

        self.viewer = viewer
        self.dataframes = {}

        self.import_button = widgets.PushButton(name="Import", text="Import Pickle")
        self.import_button.clicked.connect(self._get_load_path)

        self.append(self.import_button)

    def _get_load_path(self, event):
        """Ask the user for a pickle file and load it."""
        filename, _ = QFileDialog.getOpenFileName(
            self.native,
            "Open Pickle File...",
            "",
            "Pickle Files (*.pickle);;All Files (*)",
        )
        if filename:
            self._load_data(filename)

    def _load_data(self, filename):
        """Add the contents of a template matching pickle to the viewer.

        Two layouts are distinguished by comparing ``data[0].ndim`` with
        ``data[2].ndim``:

        * equal: ``data[0]`` is added as the "_scores" image and ``data[2]``
          as the "_rotations" image; ``data[-1][1]`` and ``data[-1][2]`` are
          stored as ``origin`` and ``sampling_rate`` metadata.
        * different: ``data[0]`` is added as point coordinates with
          ``data[2]`` as per-point scores and ``data[3]`` (zeros if absent)
          as per-point details.
        """
        # NOTE(review): load_pickle presumably unpickles the file, which can
        # execute arbitrary code; only open result files from trusted sources.
        data = load_pickle(filename)

        fname = basename(filename).replace(".pickle", "")
        if data[0].ndim == data[2].ndim:
            metadata = {"origin": data[-1][1], "sampling_rate": data[-1][2]}
            # Give each layer its own dict so metadata edits on one layer
            # (e.g. alignment flags) cannot leak into the other.
            _ = self.viewer.add_image(
                data=data[2],
                name=f"{fname}_rotations",
                colormap="orange",
                metadata=dict(metadata),
            )
            _ = self.viewer.add_image(
                data=data[0],
                name=f"{fname}_scores",
                colormap="turbo",
                metadata=dict(metadata),
            )
            return None

        # Candidate lists without a details entry fall back to all zeros.
        detail = data[3] if len(data) > 3 else np.zeros_like(data[2])

        point_properties = {"score": data[2], "detail": detail}
        point_properties["score_scaled"] = np.log1p(
            point_properties["score"] - point_properties["score"].min()
        )
        layer_name = f"{fname}_candidates"
        self.viewer.add_points(
            data[0],
            size=10,
            properties=point_properties,
            name=layer_name,
        )
|
1193
|
+
|
1194
|
+
|
1195
|
+
def main():
    """Build the napari viewer, attach every tool widget as a dock, and start the GUI."""
    viewer = napari.Viewer()

    # Widgets are constructed in the same order as before, since each one
    # subscribes to viewer events in its constructor.
    filter_widget = FilterWidget(preprocessor, viewer)
    mask_widget = MaskWidget(viewer)
    export_widget = ExportWidget(viewer)
    point_cloud = PointCloudWidget(viewer)
    matching_widget = MatchingWidget(viewer)
    alignment_widget = AlignmentWidget(viewer)

    # (widget, dock title, dock area), listed in the order the docks are added.
    dock_layout = (
        (filter_widget, "Preprocess", "right"),
        (alignment_widget, "Alignment", "right"),
        (mask_widget, "Mask", "right"),
        (export_widget, "Export", "right"),
        (point_cloud, "PointCloud", "left"),
        (matching_widget, "Matching", "left"),
    )
    for dock_widget, dock_title, dock_area in dock_layout:
        viewer.window.add_dock_widget(
            widget=dock_widget, name=dock_title, area=dock_area
        )

    napari.run()
|
1215
|
+
|
1216
|
+
|
1217
|
+
def parse_args():
    """Parse the command line; the GUI defines no options beyond ``--help``."""
    description = "GUI for preparing and analyzing template matching runs."
    return argparse.ArgumentParser(description=description).parse_args()
|
1223
|
+
|
1224
|
+
|
1225
|
+
if __name__ == "__main__":
    # Parsing validates argv (and serves --help) before the GUI starts; no
    # options are consumed yet.
    _cli_args = parse_args()
    main()
|