pytme 0.1.5__cp311-cp311-macosx_14_0_arm64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- pytme-0.1.5.data/scripts/estimate_ram_usage.py +81 -0
- pytme-0.1.5.data/scripts/match_template.py +744 -0
- pytme-0.1.5.data/scripts/postprocess.py +279 -0
- pytme-0.1.5.data/scripts/preprocess.py +93 -0
- pytme-0.1.5.data/scripts/preprocessor_gui.py +729 -0
- pytme-0.1.5.dist-info/LICENSE +153 -0
- pytme-0.1.5.dist-info/METADATA +69 -0
- pytme-0.1.5.dist-info/RECORD +63 -0
- pytme-0.1.5.dist-info/WHEEL +5 -0
- pytme-0.1.5.dist-info/entry_points.txt +6 -0
- pytme-0.1.5.dist-info/top_level.txt +2 -0
- scripts/__init__.py +0 -0
- scripts/estimate_ram_usage.py +81 -0
- scripts/match_template.py +744 -0
- scripts/match_template_devel.py +788 -0
- scripts/postprocess.py +279 -0
- scripts/preprocess.py +93 -0
- scripts/preprocessor_gui.py +729 -0
- tme/__init__.py +6 -0
- tme/__version__.py +1 -0
- tme/analyzer.py +1144 -0
- tme/backends/__init__.py +134 -0
- tme/backends/cupy_backend.py +309 -0
- tme/backends/matching_backend.py +1154 -0
- tme/backends/npfftw_backend.py +763 -0
- tme/backends/pytorch_backend.py +526 -0
- tme/data/__init__.py +0 -0
- tme/data/c48n309.npy +0 -0
- tme/data/c48n527.npy +0 -0
- tme/data/c48n9.npy +0 -0
- tme/data/c48u1.npy +0 -0
- tme/data/c48u1153.npy +0 -0
- tme/data/c48u1201.npy +0 -0
- tme/data/c48u1641.npy +0 -0
- tme/data/c48u181.npy +0 -0
- tme/data/c48u2219.npy +0 -0
- tme/data/c48u27.npy +0 -0
- tme/data/c48u2947.npy +0 -0
- tme/data/c48u3733.npy +0 -0
- tme/data/c48u4749.npy +0 -0
- tme/data/c48u5879.npy +0 -0
- tme/data/c48u7111.npy +0 -0
- tme/data/c48u815.npy +0 -0
- tme/data/c48u83.npy +0 -0
- tme/data/c48u8649.npy +0 -0
- tme/data/c600v.npy +0 -0
- tme/data/c600vc.npy +0 -0
- tme/data/metadata.yaml +80 -0
- tme/data/quat_to_numpy.py +42 -0
- tme/data/scattering_factors.pickle +0 -0
- tme/density.py +2314 -0
- tme/extensions.cpython-311-darwin.so +0 -0
- tme/helpers.py +881 -0
- tme/matching_data.py +377 -0
- tme/matching_exhaustive.py +1553 -0
- tme/matching_memory.py +382 -0
- tme/matching_optimization.py +1123 -0
- tme/matching_utils.py +1180 -0
- tme/parser.py +429 -0
- tme/preprocessor.py +1291 -0
- tme/scoring.py +866 -0
- tme/structure.py +1428 -0
- tme/types.py +10 -0
@@ -0,0 +1,729 @@
|
|
1
|
+
#!python
|
2
|
+
""" Simplify picking adequate filtering and masking parameters using a GUI.
|
3
|
+
Exposes tme.preprocessor.Preprocessor and tme.matching_utils member functions
|
4
|
+
to achieve this aim.
|
5
|
+
|
6
|
+
Copyright (c) 2023 European Molecular Biology Laboratory
|
7
|
+
|
8
|
+
Author: Valentin Maurer <valentin.maurer@embl-hamburg.de>
|
9
|
+
"""
|
10
|
+
import inspect
|
11
|
+
from typing import Tuple, Callable, List
|
12
|
+
from typing_extensions import Annotated
|
13
|
+
|
14
|
+
import numpy as np
|
15
|
+
import pandas as pd
|
16
|
+
import napari
|
17
|
+
from napari.layers import Image
|
18
|
+
from napari.utils.events import EventedList
|
19
|
+
|
20
|
+
from magicgui import widgets
|
21
|
+
from qtpy.QtWidgets import QFileDialog
|
22
|
+
from numpy.typing import NDArray
|
23
|
+
|
24
|
+
from tme import Preprocessor, Density
|
25
|
+
from tme.matching_utils import create_mask
|
26
|
+
|
27
|
+
# Shared Preprocessor instance used by the filter wrappers below and by main().
preprocessor = Preprocessor()

# Bounds used by widgets_from_function when building numeric parameter widgets.
SLIDER_MIN = 0
SLIDER_MAX = 25
|
29
|
+
|
30
|
+
|
31
|
+
def gaussian_filter(template, sigma: float, **kwargs: dict):
    """Run ``preprocessor.gaussian_filter`` on ``template`` with width ``sigma``.

    Any additional keyword arguments are passed through unchanged.
    """
    call_kwargs = {"template": template, "sigma": sigma}
    call_kwargs.update(kwargs)
    return preprocessor.gaussian_filter(**call_kwargs)
|
33
|
+
|
34
|
+
|
35
|
+
def bandpass_filter(
    template,
    minimum_frequency: float,
    maximum_frequency: float,
    gaussian_sigma: float,
    **kwargs: dict,
):
    """Run ``preprocessor.bandpass_filter`` on ``template``.

    ``sampling_rate`` is fixed to 1, so the frequency bounds are presumably
    expressed per voxel (TODO confirm against Preprocessor.bandpass_filter).
    Additional keyword arguments are passed through unchanged.
    """
    band = {
        "minimum_frequency": minimum_frequency,
        "maximum_frequency": maximum_frequency,
        "gaussian_sigma": gaussian_sigma,
    }
    return preprocessor.bandpass_filter(
        template=template, sampling_rate=1, **band, **kwargs
    )
|
50
|
+
|
51
|
+
|
52
|
+
def difference_of_gaussian_filter(
    template, sigmas: Tuple[float, float], **kwargs: dict
):
    """Run ``preprocessor.difference_of_gaussian_filter`` on ``template``.

    ``sigmas`` is unpacked as ``(low_sigma, high_sigma)``; it must contain
    exactly two values. Additional keyword arguments are passed through.
    """
    (lower, upper) = sigmas
    return preprocessor.difference_of_gaussian_filter(
        template=template,
        low_sigma=lower,
        high_sigma=upper,
        **kwargs,
    )
|
59
|
+
|
60
|
+
|
61
|
+
def edge_gaussian_filter(
    template,
    sigma: float,
    edge_algorithm: Annotated[
        str,
        {"choices": ["sobel", "prewitt", "laplace", "gaussian", "gaussian_laplace"]},
    ],
    reverse: bool = False,
    **kwargs: dict,
):
    """Run ``preprocessor.edge_gaussian_filter`` on ``template``.

    The ``Annotated`` choices on ``edge_algorithm`` make widgets_from_function
    render it as a dropdown. Extra keyword arguments are accepted but not
    forwarded.
    """
    return preprocessor.edge_gaussian_filter(
        edge_algorithm=edge_algorithm,
        reverse=reverse,
        sigma=sigma,
        template=template,
    )
|
77
|
+
|
78
|
+
|
79
|
+
def local_gaussian_filter(
    template,
    lbd: float,
    sigma_range: Tuple[float, float],
    gaussian_sigma: float,
    reverse: bool = False,
    **kwargs: dict,
):
    """Run ``preprocessor.local_gaussian_filter`` on ``template``.

    NOTE(review): ``reverse`` shows up as a checkbox in the GUI (it is a bool
    parameter picked up by widgets_from_function), but it is not forwarded,
    and neither are extra keyword arguments. Confirm whether that is intended.
    """
    options = {
        "lbd": lbd,
        "sigma_range": sigma_range,
        "gaussian_sigma": gaussian_sigma,
    }
    return preprocessor.local_gaussian_filter(template=template, **options)
|
93
|
+
|
94
|
+
|
95
|
+
def ntree(
    template,
    sigma_range: Tuple[float, float],
    **kwargs: dict,
):
    """Run ``preprocessor.ntree_filter`` on ``template``.

    Extra keyword arguments are accepted but ignored.
    """
    ntree_filter = preprocessor.ntree_filter
    return ntree_filter(template=template, sigma_range=sigma_range)
|
101
|
+
|
102
|
+
|
103
|
+
def mean(
    template,
    width: int,
    **kwargs: dict,
):
    """Run ``preprocessor.mean_filter`` on ``template`` with the given ``width``.

    Extra keyword arguments are accepted but ignored.
    """
    return preprocessor.mean_filter(width=width, template=template)
|
109
|
+
|
110
|
+
|
111
|
+
def wedge(
    template: NDArray,
    tilt_start: float,
    tilt_stop: float,
    gaussian_sigma: float,
    tilt_axis: int = 1,
    infinite_plane: bool = True,
    extrude_plane: bool = True,
):
    """
    Multiply the Fourier transform of ``template`` by a continuous wedge mask.

    Parameters
    ----------
    template : NDArray
        Real-valued input array.
    tilt_start, tilt_stop : float
        Passed as ``start_tilt`` / ``stop_tilt`` to
        ``preprocessor.continuous_wedge_mask``.
    gaussian_sigma : float
        Passed as ``sigma`` to the wedge mask builder.
    tilt_axis : int, optional
        Tilt axis forwarded to the mask builder, by default 1.
    infinite_plane, extrude_plane : bool, optional
        Forwarded unchanged to the mask builder.

    Returns
    -------
    NDArray
        Real-valued filtered array with the same shape as ``template``.
    """
    template_ft = np.fft.rfftn(template)
    # omit_negative_frequencies=True requests a mask matching the rfftn
    # half-spectrum layout used above.
    wedge_mask = preprocessor.continuous_wedge_mask(
        start_tilt=tilt_start,
        stop_tilt=tilt_stop,
        tilt_axis=tilt_axis,
        shape=template.shape,
        sigma=gaussian_sigma,
        omit_negative_frequencies=True,
        infinite_plane=infinite_plane,
        extrude_plane=extrude_plane,
    )
    np.multiply(template_ft, wedge_mask, out=template_ft)

    # Pass the original shape: without ``s`` irfftn assumes the last axis had
    # even length 2 * (m - 1), so odd-sized inputs would come back resized.
    # irfftn already returns a real array, so no np.real is needed.
    return np.fft.irfftn(template_ft, s=template.shape)
|
134
|
+
|
135
|
+
|
136
|
+
def widgets_from_function(function: Callable, exclude_params: List = None):
    """
    Create magicgui widgets for the parameters of ``function``.

    The widget type is chosen from each parameter's type annotation:

    - ``float`` -> ``FloatSpinBox`` (minimum SLIDER_MIN, step 0.5)
    - ``Tuple[float, float]`` -> ``FloatRangeSlider`` (SLIDER_MIN to SLIDER_MAX)
    - ``int`` -> ``SpinBox``
    - ``bool`` -> ``CheckBox``
    - ``Annotated[..., {"choices": [...]}]`` -> ``ComboBox``

    Parameters with any other annotation, or with ``Annotated`` metadata that
    is not a dict holding "choices", are skipped.

    Parameters
    ----------
    function : Callable
        Callable whose signature is inspected.
    exclude_params : list of str, optional
        Parameter names that never get a widget. Defaults to ``["self"]``.

    Returns
    -------
    list
        Widgets in signature order, each named after its parameter. A
        parameter's default (when present) becomes the widget's initial value.
    """
    if exclude_params is None:
        exclude_params = ["self"]

    ret = []
    for name, param in inspect.signature(function).parameters.items():
        if name in exclude_params:
            continue

        annotation = param.annotation
        has_default = param.default is not inspect.Parameter.empty

        if annotation is float:
            widget = widgets.FloatSpinBox(
                name=name,
                value=param.default if has_default else 0,
                min=SLIDER_MIN,
                step=0.5,
            )
        elif annotation == Tuple[float, float]:
            widget = widgets.FloatRangeSlider(
                name=name,
                value=param.default if has_default else (0.0, SLIDER_MAX / 2),
                min=SLIDER_MIN,
                max=SLIDER_MAX,
            )
        elif annotation is int:
            widget = widgets.SpinBox(
                name=name,
                value=param.default if has_default else 0,
            )
        elif annotation is bool:
            widget = widgets.CheckBox(
                name=name,
                value=param.default if has_default else False,
            )
        elif hasattr(annotation, "__metadata__"):
            metadata = annotation.__metadata__[0]
            # Skip unsupported metadata instead of falling through with the
            # previous iteration's ``widget`` (which would append a duplicate).
            if not isinstance(metadata, dict) or "choices" not in metadata:
                continue
            choices = metadata["choices"]
            widget = widgets.ComboBox(
                name=name,
                choices=choices,
                value=param.default if has_default else choices[0],
            )
        else:
            continue
        ret.append(widget)
    return ret
|
185
|
+
|
186
|
+
|
187
|
+
# Preprocessor method names whose GUI parameters come from a local wrapper
# (above) instead of the Preprocessor method itself; see
# FilterWidget._get_function.
WRAPPED_FUNCTIONS = dict(
    gaussian_filter=gaussian_filter,
    bandpass_filter=bandpass_filter,
    edge_gaussian_filter=edge_gaussian_filter,
    ntree_filter=ntree,
    local_gaussian_filter=local_gaussian_filter,
    difference_of_gaussian_filter=difference_of_gaussian_filter,
    mean_filter=mean,
    continuous_wedge_mask=wedge,
)
|
197
|
+
|
198
|
+
# Public Preprocessor methods that FilterWidget hides from its filter dropdown.
# ``continuous_wedge_mask`` is intentionally not listed: it is offered through
# the ``wedge`` wrapper registered in WRAPPED_FUNCTIONS.
EXCLUDED_FUNCTIONS = [
    "apply_method",
    "method_to_id",
    "wedge_mask",
    "fourier_crop",
    "fourier_uncrop",
    "interpolate_box",
    "molmap",
    "local_gaussian_alignment_filter",
    "bandpass_mask",
]
|
211
|
+
|
212
|
+
|
213
|
+
class FilterWidget(widgets.Container):
    """
    Dock widget that applies a Preprocessor filter to a chosen layer.

    The filter dropdown lists public methods of ``preprocessor`` that are not
    in EXCLUDED_FUNCTIONS. Parameter widgets are generated from the signature
    of the matching WRAPPED_FUNCTIONS entry, or of the Preprocessor method when
    no wrapper exists. Results go to a layer named ``"<source> (<filter>)"``:
    it is overwritten if it already holds output of the same filter, otherwise
    a new image layer is added.
    """

    def __init__(self, preprocessor, viewer):
        super().__init__(layout="vertical")

        self.preprocessor = preprocessor
        self.viewer = viewer
        # Dropdown label -> Preprocessor method name.
        self.name_mapping = {}
        # Parameter widgets of the currently selected filter.
        self.action_widgets = []

        self.layer_dropdown = widgets.ComboBox(
            name="Target Layer", choices=self._get_layer_names()
        )
        self.append(self.layer_dropdown)
        self.viewer.layers.events.inserted.connect(self._update_layer_dropdown)
        self.viewer.layers.events.removed.connect(self._update_layer_dropdown)

        self.method_dropdown = widgets.ComboBox(
            name="Choose Filter", choices=self._get_method_names()
        )
        self.method_dropdown.changed.connect(self._on_method_changed)
        self.append(self.method_dropdown)

        self.apply_btn = widgets.PushButton(text="Apply Filter", enabled=False)
        self.apply_btn.changed.connect(self._action)
        self.append(self.apply_btn)

        # Create GUI for initially selected filtering method
        self._on_method_changed(None)

    def _get_method_names(self):
        """Return dropdown labels for eligible methods and record their mapping."""
        method_names = [
            name
            for name, _ in inspect.getmembers(self.preprocessor, inspect.ismethod)
            if not name.startswith("_") and name not in EXCLUDED_FUNCTIONS
        ]

        sanitized_names = [self._sanitize_name(name) for name in method_names]
        self.name_mapping.update(dict(zip(sanitized_names, method_names)))

        return sanitized_names

    def _sanitize_name(self, name: str) -> str:
        """Turn a method name into a label, e.g. ``"mean_filter"`` -> ``"Mean"``."""
        for remove in ("blur", "filter"):
            name = name.replace(remove, "")
        # Convert underscores before stripping so names ending in "_filter"
        # do not leave a trailing space in the label.
        return name.replace("_", " ").strip().title()

    def _desanitize_name(self, name: str) -> str:
        """Map a dropdown label back to its Preprocessor method name."""
        if name in self.name_mapping:
            return self.name_mapping[name]

        # Fallback for labels not produced by _get_method_names: undo the
        # space substitution and pick the first method with that prefix.
        name = name.lower().strip().replace(" ", "_")
        for function_name, _ in inspect.getmembers(self.preprocessor, inspect.ismethod):
            if function_name.startswith(name):
                return function_name
        return name

    def _get_function(self, name: str):
        """Return the GUI wrapper for ``name`` or the Preprocessor method itself."""
        if name is None:
            return None
        function = WRAPPED_FUNCTIONS.get(name, None)
        if not function:
            function = getattr(self.preprocessor, name, None)
        return function

    def _on_method_changed(self, event=None):
        """Rebuild the parameter widgets for the selected filter."""
        for widget in self.action_widgets:
            self.remove(widget)
        self.action_widgets.clear()

        function_name = self.name_mapping.get(self.method_dropdown.value)
        function = self._get_function(function_name)
        if function is None:
            return

        parameter_widgets = widgets_from_function(
            function, exclude_params=["self", "template"]
        )
        for widget in parameter_widgets:
            self.action_widgets.append(widget)
            # Insert before the trailing "Apply Filter" button.
            self.insert(-1, widget)

    def _update_layer_dropdown(self, event: EventedList):
        """Update the dropdown menu when layers change."""
        self.layer_dropdown.choices = self._get_layer_names()
        self.apply_btn.enabled = bool(self.viewer.layers)

    def _get_layer_names(self):
        """Return list of layer names in the viewer."""
        return sorted([layer.name for layer in self.viewer.layers])

    def _action(self, event):
        """Apply the selected filter to the target layer and publish the result."""
        layer_name = self.layer_dropdown.value
        if layer_name is None or layer_name not in self.viewer.layers:
            return

        method_label = self.method_dropdown.value
        function = self._get_function(self.name_mapping.get(method_label))
        if function is None:
            return

        selected_layer = self.viewer.layers[layer_name]
        selected_layer_metadata = selected_layer.metadata.copy()
        kwargs = {widget.name: widget.value for widget in self.action_widgets}

        processed_data = function(selected_layer.data, **kwargs)

        new_layer_name = f"{selected_layer.name} ({method_label})"
        if new_layer_name in self.viewer.layers:
            selected_layer = self.viewer.layers[new_layer_name]

        filter_name = self._desanitize_name(method_label)
        used_filter = selected_layer.metadata.get("used_filter", False)
        if used_filter == filter_name:
            selected_layer.data = processed_data
            return

        new_layer = self.viewer.add_image(
            data=processed_data,
            name=new_layer_name,
        )
        metadata = selected_layer_metadata.copy()
        # Copy the history list: the shallow dict copy above still shares it
        # with the source layer, and appending would mutate that layer too.
        metadata["filter_parameters"] = list(metadata.get("filter_parameters", []))
        metadata["filter_parameters"].append({filter_name: kwargs.copy()})
        metadata["used_filter"] = filter_name
        new_layer.metadata = metadata
|
327
|
+
|
328
|
+
|
329
|
+
def sphere_mask(
    template: NDArray, center_x: float, center_y: float, center_z: float, radius: float
):
    """Return ``create_mask(mask_type="ellipse")`` with one scalar ``radius``.

    The mask has the same shape as ``template``; the three center components
    are passed as a tuple in x, y, z parameter order.
    """
    mask_center = (center_x, center_y, center_z)
    return create_mask(
        radius=radius,
        center=mask_center,
        shape=template.shape,
        mask_type="ellipse",
    )
|
338
|
+
|
339
|
+
|
340
|
+
def ellipsod_mask(
    template: NDArray,
    center_x: float,
    center_y: float,
    center_z: float,
    radius_x: float,
    radius_y: float,
    radius_z: float,
):
    """Return ``create_mask(mask_type="ellipse")`` with per-axis radii.

    The mask has the same shape as ``template``.
    """
    mask_center = (center_x, center_y, center_z)
    radii = (radius_x, radius_y, radius_z)
    return create_mask(
        radius=radii,
        center=mask_center,
        shape=template.shape,
        mask_type="ellipse",
    )
|
355
|
+
|
356
|
+
|
357
|
+
def box_mask(
    template: NDArray,
    center_x: float,
    center_y: float,
    center_z: float,
    height_x: int,
    height_y: int,
    height_z: int,
):
    """Return ``create_mask(mask_type="box")`` with per-axis heights.

    The mask has the same shape as ``template``.
    """
    mask_center = (center_x, center_y, center_z)
    extents = (height_x, height_y, height_z)
    return create_mask(
        height=extents,
        center=mask_center,
        shape=template.shape,
        mask_type="box",
    )
|
372
|
+
|
373
|
+
|
374
|
+
def tube_mask(
    template: NDArray,
    symmetry_axis: int,
    center_x: float,
    center_y: float,
    center_z: float,
    inner_radius: float,
    outer_radius: float,
    height: int,
):
    """Return ``create_mask(mask_type="tube")`` for the shape of ``template``.

    The center components are forwarded as ``base_center``; the remaining
    parameters are forwarded unchanged.
    """
    base = (center_x, center_y, center_z)
    return create_mask(
        mask_type="tube",
        shape=template.shape,
        base_center=base,
        symmetry_axis=symmetry_axis,
        height=height,
        inner_radius=inner_radius,
        outer_radius=outer_radius,
    )
|
393
|
+
|
394
|
+
|
395
|
+
def wedge_mask(
    template: NDArray,
    tilt_start: float,
    tilt_stop: float,
    gaussian_sigma: float,
    tilt_axis: int = 1,
    omit_negative_frequencies: bool = True,
    extrude_plane: bool = True,
    infinite_plane: bool = True,
):
    """Build a continuous wedge mask for ``template``'s shape and fftshift it.

    ``fftshift`` moves the zero-frequency entry to the center of every axis.
    NOTE(review): with ``omit_negative_frequencies=True`` the mask presumably
    has a half-spectrum last axis, which is shifted as well — confirm intent.
    """
    mask = preprocessor.continuous_wedge_mask(
        shape=template.shape,
        start_tilt=tilt_start,
        stop_tilt=tilt_stop,
        tilt_axis=tilt_axis,
        sigma=gaussian_sigma,
        infinite_plane=infinite_plane,
        extrude_plane=extrude_plane,
        omit_negative_frequencies=omit_negative_frequencies,
    )
    return np.fft.fftshift(mask)
|
417
|
+
|
418
|
+
def threshold_mask(
|
419
|
+
template: NDArray,
|
420
|
+
standard_deviation : float = 5.0,
|
421
|
+
invert : bool = False
|
422
|
+
):
|
423
|
+
template_mean = template.mean()
|
424
|
+
template_deviation = standard_deviation * template.std()
|
425
|
+
upper = template_mean + template_deviation
|
426
|
+
lower = template_mean - template_deviation
|
427
|
+
mask = np.logical_and(template > lower, template < upper)
|
428
|
+
if invert:
|
429
|
+
np.invert(mask, out = mask)
|
430
|
+
|
431
|
+
return mask
|
432
|
+
|
433
|
+
|
434
|
+
|
435
|
+
class MaskWidget(widgets.Container):
    """
    Dock widget that builds masks for the active layer.

    Mask builders are listed in ``self.methods``; their parameter widgets are
    generated from each builder's signature. The mask is written as float32
    either into an existing ``"<layer> (<mask>)"`` layer of the same mask type
    or into a newly added image layer. The share of the origin layer's
    positive density inside the mask is shown in a label.
    """

    def __init__(self, viewer):
        super().__init__(layout="vertical")

        self.viewer = viewer
        # Parameter widgets of the currently selected mask builder.
        self.action_widgets = []

        self.action_button = widgets.PushButton(text="Create mask", enabled=False)
        self.action_button.changed.connect(self._action)

        self.methods = {
            "Sphere": sphere_mask,
            "Ellipsoid": ellipsod_mask,
            "Tube": tube_mask,
            "Box": box_mask,
            "Wedge": wedge_mask,
            "Threshold": threshold_mask,
        }

        self.method_dropdown = widgets.ComboBox(
            name="Choose Mask", choices=list(self.methods.keys())
        )
        self.method_dropdown.changed.connect(self._on_method_changed)

        self.adapt_button = widgets.PushButton(text="Adapt to layer", enabled=False)
        self.adapt_button.changed.connect(self._update_initial_values)

        self.viewer.layers.selection.events.active.connect(
            self._update_action_button_state
        )

        self.align_button = widgets.PushButton(text="Align to axis", enabled=False)
        self.align_button.changed.connect(self._align_with_axis)
        self.density_field = widgets.Label()

        self.append(self.method_dropdown)
        self.append(self.adapt_button)
        self.append(self.align_button)
        self.append(self.action_button)
        self.append(self.density_field)

        # Create GUI for initially selected mask method
        self._on_method_changed(None)

    def _update_action_button_state(self, event):
        """Enable the buttons only while a layer is active."""
        has_active_layer = bool(self.viewer.layers.selection.active)
        self.align_button.enabled = has_active_layer
        self.action_button.enabled = has_active_layer
        self.adapt_button.enabled = has_active_layer

    def _align_with_axis(self):
        """
        Rotate the active 3D layer so its principal axis lies along axis 0.

        The principal axis is the eigenvector with the largest eigenvalue of
        the covariance of positive-voxel coordinates. Layers already flagged
        with ``metadata["is_aligned"]`` are left untouched.
        """
        active_layer = self.viewer.layers.selection.active
        if active_layer is None or active_layer.metadata.get("is_aligned", False):
            return

        coords = np.array(np.where(active_layer.data > 0)).T
        # The rotation below is built from 3-vectors, and a covariance needs
        # at least two samples.
        if coords.shape[1] != 3 or coords.shape[0] < 2:
            return

        centered_coords = coords - np.mean(coords, axis=0)
        cov_matrix = np.cov(centered_coords, rowvar=False)

        eigenvalues, eigenvectors = np.linalg.eigh(cov_matrix)
        principal_eigenvector = eigenvectors[:, np.argmax(eigenvalues)]

        target_axis = np.array([1.0, 0.0, 0.0])
        rotation_axis = np.cross(principal_eigenvector, target_axis)
        axis_norm = np.linalg.norm(rotation_axis)
        if axis_norm < 1e-12:
            # Already (anti-)parallel to axis 0. The eigenvector sign is
            # arbitrary, so no rotation is needed, and normalizing the zero
            # cross product would divide by zero.
            active_layer.metadata["is_aligned"] = True
            return

        # Clip guards arccos against dot products drifting just past +/-1.
        cos_angle = np.clip(np.dot(principal_eigenvector, target_axis), -1.0, 1.0)
        rotation_angle = np.arccos(cos_angle)
        k = rotation_axis / axis_norm
        K = np.array([[0, -k[2], k[1]], [k[2], 0, -k[0]], [-k[1], k[0], 0]])
        # Rodrigues' rotation formula: R = I + sin(a) K + (1 - cos(a)) K^2
        rotation_matrix = np.eye(3)
        rotation_matrix += np.sin(rotation_angle) * K
        rotation_matrix += (1 - np.cos(rotation_angle)) * np.dot(K, K)

        rotated_data = Density.rotate_array(
            arr=active_layer.data,
            rotation_matrix=rotation_matrix,
            use_geometric_center=False,
        )
        eps = np.finfo(rotated_data.dtype).eps
        rotated_data[rotated_data < eps] = 0

        active_layer.metadata["is_aligned"] = True
        active_layer.data = rotated_data

    def _update_initial_values(self, event=None):
        """Seed parameter widgets from the bounding box of positive voxels."""
        active_layer = self.viewer.layers.selection.active
        if active_layer is None:
            return

        coordinates = np.array(np.where(active_layer.data > 0))
        if coordinates.size == 0:
            # No positive voxels: there is no bounding box to adapt to.
            return

        coordinates_min = coordinates.min(axis=1)
        coordinates_max = coordinates.max(axis=1)
        coordinates_heights = coordinates_max - coordinates_min
        coordinate_radius = np.divide(coordinates_heights, 2)
        # Center of the bounding box of positive voxels.
        box_center = coordinate_radius + coordinates_min

        defaults = dict(zip(["center_x", "center_y", "center_z"], box_center))
        defaults.update(
            dict(zip(["radius_x", "radius_y", "radius_z"], coordinate_radius))
        )
        defaults.update(
            dict(zip(["height_x", "height_y", "height_z"], coordinates_heights))
        )

        defaults["radius"] = np.min(coordinate_radius)
        defaults["inner_radius"] = np.min(coordinate_radius)
        defaults["outer_radius"] = np.max(coordinate_radius)
        defaults["height"] = defaults["radius"]

        for widget in self.action_widgets:
            if widget.name in defaults:
                widget.value = defaults[widget.name]

    def _on_method_changed(self, event=None):
        """Rebuild the parameter widgets for the selected mask builder."""
        for widget in self.action_widgets:
            self.remove(widget)
        self.action_widgets.clear()

        function = self.methods.get(self.method_dropdown.value)
        if function is None:
            return

        parameter_widgets = widgets_from_function(function)
        # Place the widgets right after the dropdown (index 0) in signature
        # order; inserting each one at index 1 would reverse that order.
        for position, widget in enumerate(parameter_widgets, start=1):
            self.action_widgets.append(widget)
            self.insert(position, widget)

    def _action(self):
        """Build the selected mask for the active layer and report its coverage."""
        method_label = self.method_dropdown.value
        function = self.methods.get(method_label)
        selected_layer = self.viewer.layers.selection.active
        if function is None or selected_layer is None:
            return

        kwargs = {widget.name: widget.value for widget in self.action_widgets}
        processed_data = function(template=selected_layer.data, **kwargs)
        new_layer_name = f"{selected_layer.name} ({method_label})"

        if new_layer_name in self.viewer.layers:
            selected_layer = self.viewer.layers[new_layer_name]

        processed_data = processed_data.astype(np.float32)
        metadata = selected_layer.metadata
        mask = metadata.get("mask", False)
        if mask == method_label:
            selected_layer.data = processed_data
            # Keep the recorded parameters in sync with the data just written.
            metadata["filter_parameters"] = {method_label: kwargs.copy()}
        else:
            new_layer = self.viewer.add_image(
                data=processed_data,
                name=new_layer_name,
            )
            metadata = selected_layer.metadata.copy()
            metadata["filter_parameters"] = {method_label: kwargs.copy()}
            metadata["mask"] = method_label
            metadata["origin_layer"] = selected_layer.name
            new_layer.metadata = metadata

        origin_name = metadata.get("origin_layer")
        if origin_name is None or origin_name not in self.viewer.layers:
            return

        origin_layer = self.viewer.layers[origin_name]
        if tuple(origin_layer.data.shape) != tuple(processed_data.shape):
            return

        positive_total = np.sum(np.fmax(origin_layer.data, 0))
        if positive_total <= 0:
            # Avoid dividing by zero when the origin has no positive density.
            self.density_field.value = "Positive Density in Mask: n/a"
            return

        in_mask = np.sum(np.fmax(origin_layer.data * processed_data, 0))
        in_mask = in_mask / positive_total * 100
        self.density_field.value = f"Positive Density in Mask: {in_mask:.2f}%"
|
590
|
+
|
591
|
+
|
592
|
+
class ExportWidget(widgets.Container):
    """Dock widget that saves the active image layer via a save-file dialog.

    A "gzip" checkbox sits next to the export button. Its value is stored in
    the layer's ``metadata["write_gzip"]`` before ``layer.save`` is called.
    """

    def __init__(self, viewer):
        super().__init__(layout="vertical")

        self.viewer = viewer
        # Path chosen in the most recent save dialog.
        self.selected_filename = ""

        button_row = widgets.Container(layout="horizontal")

        self.gzip_output = widgets.CheckBox(name="gzip", value=False, label="gzip")
        self.export_button = widgets.PushButton(name="Export", text="Export")
        self.export_button.clicked.connect(self._get_save_path)

        for child in (self.export_button, self.gzip_output):
            button_row.append(child)
        self.append(button_row)

        self._update_export_button_state(None)
        self.viewer.layers.selection.events.active.connect(
            self._update_export_button_state
        )

    def _get_save_path(self, event):
        """Ask for a target path and export the active layer if one is chosen."""
        target, _selected_filter = QFileDialog.getSaveFileName(
            self.native,
            "Save As...",
            "",
            "MRC Files (*.mrc)",
            options=QFileDialog.Options(),
        )
        if not target:
            return
        self.selected_filename = target
        self._export_data()

    def _update_export_button_state(self, event):
        """Update the enabled state of the export button based on the active layer."""
        self.export_button.enabled = bool(self.viewer.layers.selection.active)

    def _export_data(self):
        """Save the active layer to ``selected_filename`` if it is an Image layer."""
        layer = self.viewer.layers.selection.active
        if not (layer and isinstance(layer, Image)):
            return
        layer.metadata["write_gzip"] = self.gzip_output.value
        layer.save(path=self.selected_filename)
|
637
|
+
|
638
|
+
|
639
|
+
class PointCloudWidget(widgets.Container):
    """
    Dock widget to import and export point clouds as tab-separated files.

    Imported files must contain ``z``, ``y`` and ``x`` columns. The full table
    is kept per layer so that extra columns can be joined back on export.
    """

    def __init__(self, viewer):
        super().__init__(layout="vertical")

        self.viewer = viewer
        # Imported tables, keyed by the name of the Points layer they created.
        self.dataframes = {}

        self.import_button = widgets.PushButton(
            name="Import", text="Import Point Cloud"
        )
        self.import_button.clicked.connect(self._get_load_path)

        self.export_button = widgets.PushButton(
            name="Export", text="Export Point Cloud"
        )
        self.export_button.clicked.connect(self._export_point_cloud)
        self.export_button.enabled = False

        self.append(self.import_button)
        self.append(self.export_button)
        self.viewer.layers.selection.events.changed.connect(self._update_buttons)

    def _update_buttons(self, event):
        """Enable export only while a Points layer is active."""
        active_layer = self.viewer.layers.selection.active
        self.export_button.enabled = bool(
            active_layer and isinstance(active_layer, napari.layers.Points)
        )

    def _export_point_cloud(self, event):
        """Write the active Points layer to a TSV file chosen by the user."""
        options = QFileDialog.Options()
        filename, _ = QFileDialog.getSaveFileName(
            self.native,
            "Save Point Cloud File...",
            "",
            "TSV Files (*.tsv);;All Files (*)",
            options=options,
        )
        if not filename:
            return

        layer = self.viewer.layers.selection.active
        if not (layer and isinstance(layer, napari.layers.Points)):
            return

        export_data = pd.DataFrame(layer.data, columns=["z", "y", "x"])
        original_dataframe = self.dataframes.get(layer.name)
        if original_dataframe is not None:
            # Left-join on coordinates to carry over the imported file's extra
            # columns. Layers not imported here (no stored table) are exported
            # as plain coordinates; merging with an empty DataFrame would raise
            # because it lacks the join columns.
            export_data = pd.merge(
                export_data, original_dataframe, on=["z", "y", "x"], how="left"
            )
        export_data.to_csv(filename, sep="\t", index=False)

    def _get_load_path(self, event):
        """Ask for a TSV file and load it as a Points layer."""
        options = QFileDialog.Options()
        filename, _ = QFileDialog.getOpenFileName(
            self.native,
            "Open Point Cloud File...",
            "",
            "TSV Files (*.tsv);;All Files (*)",
            options=options,
        )
        if filename:
            self._load_point_cloud(filename)

    def _load_point_cloud(self, filename):
        """Read ``filename`` and add its z/y/x columns as a Points layer."""
        dataframe = pd.read_csv(filename, sep="\t")
        points = dataframe[["z", "y", "x"]].values
        layer_name = filename.split("/")[-1]
        layer = self.viewer.add_points(points, size=10, name=layer_name)
        # Key by the name the layer actually received; napari presumably
        # de-duplicates clashing layer names, so it may differ from layer_name.
        self.dataframes[layer.name] = dataframe
|
709
|
+
|
710
|
+
|
711
|
+
def main():
    """Open a napari viewer with the preprocessing, mask, point-cloud and export docks."""
    viewer = napari.Viewer()

    # Construction order is kept as before; docking order differs from it.
    filter_dock = FilterWidget(preprocessor, viewer)
    mask_dock = MaskWidget(viewer)
    export_dock = ExportWidget(viewer)
    point_cloud_dock = PointCloudWidget(viewer)

    for dock_widget, dock_name, dock_area in (
        (filter_dock, "Preprocess", "right"),
        (mask_dock, "Mask", "right"),
        (point_cloud_dock, "PointCloud", "left"),
        (export_dock, "Export", "right"),
    ):
        viewer.window.add_dock_widget(
            widget=dock_widget, name=dock_name, area=dock_area
        )

    napari.run()


if __name__ == "__main__":
    main()
|