pytme-0.1.5-cp311-cp311-macosx_14_0_arm64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (63)
  1. pytme-0.1.5.data/scripts/estimate_ram_usage.py +81 -0
  2. pytme-0.1.5.data/scripts/match_template.py +744 -0
  3. pytme-0.1.5.data/scripts/postprocess.py +279 -0
  4. pytme-0.1.5.data/scripts/preprocess.py +93 -0
  5. pytme-0.1.5.data/scripts/preprocessor_gui.py +729 -0
  6. pytme-0.1.5.dist-info/LICENSE +153 -0
  7. pytme-0.1.5.dist-info/METADATA +69 -0
  8. pytme-0.1.5.dist-info/RECORD +63 -0
  9. pytme-0.1.5.dist-info/WHEEL +5 -0
  10. pytme-0.1.5.dist-info/entry_points.txt +6 -0
  11. pytme-0.1.5.dist-info/top_level.txt +2 -0
  12. scripts/__init__.py +0 -0
  13. scripts/estimate_ram_usage.py +81 -0
  14. scripts/match_template.py +744 -0
  15. scripts/match_template_devel.py +788 -0
  16. scripts/postprocess.py +279 -0
  17. scripts/preprocess.py +93 -0
  18. scripts/preprocessor_gui.py +729 -0
  19. tme/__init__.py +6 -0
  20. tme/__version__.py +1 -0
  21. tme/analyzer.py +1144 -0
  22. tme/backends/__init__.py +134 -0
  23. tme/backends/cupy_backend.py +309 -0
  24. tme/backends/matching_backend.py +1154 -0
  25. tme/backends/npfftw_backend.py +763 -0
  26. tme/backends/pytorch_backend.py +526 -0
  27. tme/data/__init__.py +0 -0
  28. tme/data/c48n309.npy +0 -0
  29. tme/data/c48n527.npy +0 -0
  30. tme/data/c48n9.npy +0 -0
  31. tme/data/c48u1.npy +0 -0
  32. tme/data/c48u1153.npy +0 -0
  33. tme/data/c48u1201.npy +0 -0
  34. tme/data/c48u1641.npy +0 -0
  35. tme/data/c48u181.npy +0 -0
  36. tme/data/c48u2219.npy +0 -0
  37. tme/data/c48u27.npy +0 -0
  38. tme/data/c48u2947.npy +0 -0
  39. tme/data/c48u3733.npy +0 -0
  40. tme/data/c48u4749.npy +0 -0
  41. tme/data/c48u5879.npy +0 -0
  42. tme/data/c48u7111.npy +0 -0
  43. tme/data/c48u815.npy +0 -0
  44. tme/data/c48u83.npy +0 -0
  45. tme/data/c48u8649.npy +0 -0
  46. tme/data/c600v.npy +0 -0
  47. tme/data/c600vc.npy +0 -0
  48. tme/data/metadata.yaml +80 -0
  49. tme/data/quat_to_numpy.py +42 -0
  50. tme/data/scattering_factors.pickle +0 -0
  51. tme/density.py +2314 -0
  52. tme/extensions.cpython-311-darwin.so +0 -0
  53. tme/helpers.py +881 -0
  54. tme/matching_data.py +377 -0
  55. tme/matching_exhaustive.py +1553 -0
  56. tme/matching_memory.py +382 -0
  57. tme/matching_optimization.py +1123 -0
  58. tme/matching_utils.py +1180 -0
  59. tme/parser.py +429 -0
  60. tme/preprocessor.py +1291 -0
  61. tme/scoring.py +866 -0
  62. tme/structure.py +1428 -0
  63. tme/types.py +10 -0
@@ -0,0 +1,729 @@
1
+ #!python3
2
+ """ Simplify picking adequate filtering and masking parameters using a GUI.
3
+ Exposes tme.preprocessor.Preprocessor and tme.matching_utils member functions
4
+ to achieve this aim.
5
+
6
+ Copyright (c) 2023 European Molecular Biology Laboratory
7
+
8
+ Author: Valentin Maurer <valentin.maurer@embl-hamburg.de>
9
+ """
10
+ import inspect
11
+ from typing import Tuple, Callable, List
12
+ from typing_extensions import Annotated
13
+
14
+ import numpy as np
15
+ import pandas as pd
16
+ import napari
17
+ from napari.layers import Image
18
+ from napari.utils.events import EventedList
19
+
20
+ from magicgui import widgets
21
+ from qtpy.QtWidgets import QFileDialog
22
+ from numpy.typing import NDArray
23
+
24
+ from tme import Preprocessor, Density
25
+ from tme.matching_utils import create_mask
26
+
27
# Preprocessor instance shared by the filter/mask wrapper functions below
# and handed to FilterWidget in main().
preprocessor = Preprocessor()

# SLIDER_MIN is the lower bound of the float spin boxes and range sliders
# generated by widgets_from_function; SLIDER_MAX is the upper bound of the
# range sliders.
SLIDER_MIN = 0
SLIDER_MAX = 25
29
+
30
+
31
def gaussian_filter(template, sigma: float, **kwargs: dict):
    """
    Apply ``Preprocessor.gaussian_filter`` of the shared preprocessor.

    Parameters
    ----------
    template
        Array handed to the preprocessor as ``template``.
    sigma : float
        Value forwarded as ``sigma``.
    **kwargs : dict
        Extra keyword arguments forwarded verbatim.

    Returns
    -------
    Whatever ``Preprocessor.gaussian_filter`` returns.
    """
    call_arguments = {"template": template, "sigma": sigma}
    return preprocessor.gaussian_filter(**call_arguments, **kwargs)
33
+
34
+
35
def bandpass_filter(
    template,
    minimum_frequency: float,
    maximum_frequency: float,
    gaussian_sigma: float,
    **kwargs: dict,
):
    """
    Apply ``Preprocessor.bandpass_filter`` of the shared preprocessor.

    ``sampling_rate`` is pinned to 1, so the frequency bounds are presumably
    interpreted relative to the voxel grid (confirm against
    ``Preprocessor.bandpass_filter``).

    Parameters
    ----------
    template
        Array handed to the preprocessor as ``template``.
    minimum_frequency, maximum_frequency : float
        Frequency bounds forwarded unchanged.
    gaussian_sigma : float
        Forwarded unchanged.
    **kwargs : dict
        Extra keyword arguments forwarded verbatim. Supplying
        ``sampling_rate`` here raises ``TypeError`` because it is already set.
    """
    band_limits = {
        "minimum_frequency": minimum_frequency,
        "maximum_frequency": maximum_frequency,
    }
    return preprocessor.bandpass_filter(
        template=template,
        sampling_rate=1,
        gaussian_sigma=gaussian_sigma,
        **band_limits,
        **kwargs,
    )
50
+
51
+
52
def difference_of_gaussian_filter(
    template, sigmas: Tuple[float, float], **kwargs: dict
):
    """
    Apply ``Preprocessor.difference_of_gaussian_filter``.

    Parameters
    ----------
    template
        Array handed to the preprocessor as ``template``.
    sigmas : tuple of float
        ``(low_sigma, high_sigma)`` pair, e.g. produced by the range slider
        that widgets_from_function creates for ``Tuple[float, float]``.
    **kwargs : dict
        Extra keyword arguments forwarded verbatim.
    """
    low, high = sigmas
    sigma_arguments = {"low_sigma": low, "high_sigma": high}
    return preprocessor.difference_of_gaussian_filter(
        template=template, **sigma_arguments, **kwargs
    )
59
+
60
+
61
def edge_gaussian_filter(
    template,
    sigma: float,
    edge_algorithm: Annotated[
        str,
        {"choices": ["sobel", "prewitt", "laplace", "gaussian", "gaussian_laplace"]},
    ],
    reverse: bool = False,
    **kwargs: dict,
):
    """
    Apply ``Preprocessor.edge_gaussian_filter`` of the shared preprocessor.

    ``edge_algorithm`` carries ``Annotated`` choices so that
    widgets_from_function renders it as a combo box.

    Parameters
    ----------
    template
        Array handed to the preprocessor as ``template``.
    sigma : float
        Forwarded unchanged.
    edge_algorithm : str
        One of the annotated choices, forwarded unchanged.
    reverse : bool, optional
        Forwarded unchanged, default False.
    **kwargs : dict
        Accepted for a uniform wrapper signature but not forwarded.
    """
    options = dict(sigma=sigma, reverse=reverse, edge_algorithm=edge_algorithm)
    return preprocessor.edge_gaussian_filter(template=template, **options)
77
+
78
+
79
def local_gaussian_filter(
    template,
    lbd: float,
    sigma_range: Tuple[float, float],
    gaussian_sigma: float,
    reverse: bool = False,
    **kwargs: dict,
):
    """
    Apply ``Preprocessor.local_gaussian_filter`` of the shared preprocessor.

    ``lbd``, ``sigma_range`` and ``gaussian_sigma`` are forwarded unchanged.

    NOTE(review): ``reverse`` is exposed as a GUI checkbox but is not
    forwarded, so toggling it has no effect; confirm whether the preprocessor
    method supports it before passing it through. ``**kwargs`` is likewise
    accepted but ignored.
    """
    filter_options = {
        "lbd": lbd,
        "sigma_range": sigma_range,
        "gaussian_sigma": gaussian_sigma,
    }
    return preprocessor.local_gaussian_filter(template=template, **filter_options)
93
+
94
+
95
def ntree(
    template,
    sigma_range: Tuple[float, float],
    **kwargs: dict,
):
    """
    Apply ``Preprocessor.ntree_filter`` of the shared preprocessor.

    ``sigma_range`` is forwarded unchanged; ``**kwargs`` is accepted for a
    uniform wrapper signature but ignored.
    """
    ntree_filter = preprocessor.ntree_filter
    return ntree_filter(template=template, sigma_range=sigma_range)
101
+
102
+
103
def mean(
    template,
    width: int,
    **kwargs: dict,
):
    """
    Apply ``Preprocessor.mean_filter`` of the shared preprocessor.

    ``width`` is forwarded unchanged; ``**kwargs`` is accepted for a uniform
    wrapper signature but ignored.
    """
    mean_filter = preprocessor.mean_filter
    return mean_filter(template=template, width=width)
109
+
110
+
111
def wedge(
    template: NDArray,
    tilt_start: float,
    tilt_stop: float,
    gaussian_sigma: float,
    tilt_axis: int = 1,
    infinite_plane: bool = True,
    extrude_plane: bool = True,
):
    """
    Multiply the real FFT of ``template`` by a continuous wedge mask.

    The mask is requested with ``omit_negative_frequencies=True``, presumably
    matching the half-spectrum layout of ``np.fft.rfftn`` (confirm against
    ``Preprocessor.continuous_wedge_mask``).

    Parameters
    ----------
    template : NDArray
        Real-valued array to filter.
    tilt_start, tilt_stop : float
        Forwarded as ``start_tilt`` / ``stop_tilt``.
    gaussian_sigma : float
        Forwarded as ``sigma``.
    tilt_axis : int, optional
        Forwarded unchanged, default 1.
    infinite_plane, extrude_plane : bool, optional
        Forwarded unchanged, default True.

    Returns
    -------
    NDArray
        Filtered real array with the same shape as ``template``.
    """
    template_ft = np.fft.rfftn(template)
    wedge_mask = preprocessor.continuous_wedge_mask(
        start_tilt=tilt_start,
        stop_tilt=tilt_stop,
        tilt_axis=tilt_axis,
        shape=template.shape,
        sigma=gaussian_sigma,
        omit_negative_frequencies=True,
        infinite_plane=infinite_plane,
        extrude_plane=extrude_plane,
    )
    np.multiply(template_ft, wedge_mask, out=template_ft)
    # Pass the original shape: without ``s`` irfftn assumes an even-sized last
    # axis and returns one voxel less along it for odd-sized inputs.
    # irfftn already returns a real array, so no np.real is needed.
    return np.fft.irfftn(template_ft, s=template.shape)
134
+
135
+
136
def widgets_from_function(
    function: Callable, exclude_params: "List[str] | None" = None
):
    """
    Create magicgui widgets from the annotated parameters of ``function``.

    Annotation to widget mapping:

    - ``float``: ``FloatSpinBox`` with minimum ``SLIDER_MIN`` and step 0.5
    - ``Tuple[float, float]``: ``FloatRangeSlider`` spanning ``SLIDER_MIN``
      to ``SLIDER_MAX``
    - ``int``: ``SpinBox``
    - ``bool``: ``CheckBox``
    - ``Annotated[..., {"choices": [...]}]``: ``ComboBox``

    Parameters with any other annotation, with no annotation, or with
    ``Annotated`` metadata lacking ``"choices"`` are skipped. A parameter's
    default becomes the widget's initial value. Without a default, the
    initial value is 0, ``(0.0, SLIDER_MAX / 2)``, ``False`` or the first
    choice, depending on the widget type.

    Parameters
    ----------
    function : Callable
        Callable whose signature is inspected.
    exclude_params : list of str, optional
        Parameter names to skip; defaults to ``["self"]``.

    Returns
    -------
    list
        Widgets in signature order, each named after its parameter.
    """
    if exclude_params is None:
        exclude_params = ["self"]

    ret = []
    for name, param in inspect.signature(function).parameters.items():
        if name in exclude_params:
            continue

        annotation = param.annotation
        has_default = param.default is not inspect.Parameter.empty

        if annotation is float:
            widget = widgets.FloatSpinBox(
                name=name,
                value=param.default if has_default else 0,
                min=SLIDER_MIN,
                step=0.5,
            )
        elif annotation == Tuple[float, float]:
            widget = widgets.FloatRangeSlider(
                name=name,
                value=param.default if has_default else (0.0, SLIDER_MAX / 2),
                min=SLIDER_MIN,
                max=SLIDER_MAX,
            )
        elif annotation is int:
            widget = widgets.SpinBox(
                name=name,
                value=param.default if has_default else 0,
            )
        elif annotation is bool:
            widget = widgets.CheckBox(
                name=name,
                value=param.default if has_default else False,
            )
        elif hasattr(annotation, "__metadata__"):
            metadata = annotation.__metadata__[0]
            # Skip Annotated parameters without choices; previously this left
            # `widget` unbound or re-appended the previous parameter's widget.
            if not isinstance(metadata, dict) or "choices" not in metadata:
                continue
            widget = widgets.ComboBox(
                name=name,
                choices=metadata["choices"],
                value=param.default if has_default else metadata["choices"][0],
            )
        else:
            continue
        ret.append(widget)
    return ret
185
+
186
+
187
# Preprocessor method name -> GUI wrapper used in place of the raw method
# (looked up first by FilterWidget._get_function).
WRAPPED_FUNCTIONS = dict(
    gaussian_filter=gaussian_filter,
    bandpass_filter=bandpass_filter,
    edge_gaussian_filter=edge_gaussian_filter,
    ntree_filter=ntree,
    local_gaussian_filter=local_gaussian_filter,
    difference_of_gaussian_filter=difference_of_gaussian_filter,
    mean_filter=mean,
    continuous_wedge_mask=wedge,
)
197
+
198
# Preprocessor methods hidden from the filter dropdown. Each name appears
# once; "continuous_wedge_mask" stays available through WRAPPED_FUNCTIONS.
EXCLUDED_FUNCTIONS = [
    "apply_method",
    "method_to_id",
    "wedge_mask",
    "fourier_crop",
    "fourier_uncrop",
    "interpolate_box",
    "molmap",
    "local_gaussian_alignment_filter",
    "bandpass_mask",
]
211
+
212
+
213
class FilterWidget(widgets.Container):
    """
    Dock widget applying Preprocessor filters to an image layer.

    The user selects a target layer and a filter. One parameter widget is
    shown per annotated argument of the filter, or of its wrapper in
    WRAPPED_FUNCTIONS. Applying the filter stores the result in a layer named
    ``"<target> (<filter label>)"``. If that layer already exists and was
    produced by the same filter, its data and recorded parameters are
    refreshed instead of adding a new layer.

    Parameters
    ----------
    preprocessor : Preprocessor
        Object whose public methods populate the filter dropdown.
    viewer : napari.Viewer
        Viewer whose layers are read and extended.
    """

    def __init__(self, preprocessor, viewer):
        super().__init__(layout="vertical")

        self.preprocessor = preprocessor
        self.viewer = viewer
        # Dropdown label -> preprocessor method name.
        self.name_mapping = {}
        # Parameter widgets of the currently selected filter.
        self.action_widgets = []

        self.layer_dropdown = widgets.ComboBox(
            name="Target Layer", choices=self._get_layer_names()
        )
        self.append(self.layer_dropdown)
        self.viewer.layers.events.inserted.connect(self._update_layer_dropdown)
        self.viewer.layers.events.removed.connect(self._update_layer_dropdown)

        self.method_dropdown = widgets.ComboBox(
            name="Choose Filter", choices=self._get_method_names()
        )
        self.method_dropdown.changed.connect(self._on_method_changed)
        self.append(self.method_dropdown)

        # Start enabled when the viewer already holds layers; otherwise the
        # button would stay disabled until a layer is inserted or removed.
        self.apply_btn = widgets.PushButton(
            text="Apply Filter", enabled=bool(self.viewer.layers)
        )
        self.apply_btn.changed.connect(self._action)
        self.append(self.apply_btn)

        # Create GUI for initially selected filtering method
        self._on_method_changed(None)

    def _get_method_names(self):
        """
        Return dropdown labels for the public preprocessor methods.

        Methods listed in EXCLUDED_FUNCTIONS are skipped. As a side effect the
        label -> method name pairs are recorded in ``self.name_mapping``.
        """
        method_names = [
            name
            for name, member in inspect.getmembers(self.preprocessor, inspect.ismethod)
            if not name.startswith("_") and name not in EXCLUDED_FUNCTIONS
        ]

        sanitized_names = [self._sanitize_name(name) for name in method_names]
        self.name_mapping.update(dict(zip(sanitized_names, method_names)))

        return sanitized_names

    def _sanitize_name(self, name: str) -> str:
        """
        Turn a method name into a display label.

        "blur" and "filter" are removed, underscores become spaces and each
        word is capitalized, e.g. ``"edge_gaussian_filter"`` -> ``"Edge Gaussian"``.
        """
        for remove in ("blur", "filter"):
            name = name.replace(remove, "")
        # Replace underscores before stripping; stripping first left the
        # separator of a removed suffix behind as a trailing space.
        return name.replace("_", " ").strip().title()

    def _desanitize_name(self, name: str) -> str:
        """
        Map a display label back to a preprocessor method name.

        Labels produced by ``_get_method_names`` are resolved exactly through
        ``self.name_mapping``, which also covers multi-word labels. Any other
        label falls back to the first preprocessor method starting with the
        lower-cased label, or to the lower-cased label itself.
        """
        if name in self.name_mapping:
            return self.name_mapping[name]

        name = name.lower().strip()
        for function_name, _ in inspect.getmembers(self.preprocessor, inspect.ismethod):
            if function_name.startswith(name):
                return function_name
        return name

    def _get_function(self, name: str):
        """
        Return the callable for preprocessor method ``name``.

        A GUI wrapper from WRAPPED_FUNCTIONS takes precedence over the bound
        preprocessor method. Returns None if neither exists or ``name`` is None.
        """
        if name is None:
            return None
        function = WRAPPED_FUNCTIONS.get(name)
        if function is None:
            function = getattr(self.preprocessor, name, None)
        return function

    def _on_method_changed(self, event=None):
        """Rebuild the parameter widgets for the selected filter."""
        for widget in self.action_widgets:
            self.remove(widget)
        self.action_widgets.clear()

        function_name = self.name_mapping.get(self.method_dropdown.value)
        function = self._get_function(function_name)
        if function is None:
            return

        parameter_widgets = widgets_from_function(
            function, exclude_params=["self", "template"]
        )
        for widget in parameter_widgets:
            self.action_widgets.append(widget)
            # Insert before the trailing "Apply Filter" button.
            self.insert(-1, widget)

    def _update_layer_dropdown(self, event: EventedList):
        """Update the dropdown menu when layers change."""
        self.layer_dropdown.choices = self._get_layer_names()
        self.apply_btn.enabled = bool(self.viewer.layers)

    def _get_layer_names(self):
        """Return sorted list of layer names in the viewer."""
        return sorted([layer.name for layer in self.viewer.layers])

    def _action(self, event):
        """Run the selected filter on the target layer and display the result."""
        layer_name = self.layer_dropdown.value
        if layer_name is None or layer_name not in self.viewer.layers:
            return

        function_name = self.name_mapping.get(self.method_dropdown.value)
        function = self._get_function(function_name)
        if function is None:
            return

        selected_layer = self.viewer.layers[layer_name]
        # Shallow copy: nested containers are still shared with the source.
        selected_layer_metadata = selected_layer.metadata.copy()
        kwargs = {widget.name: widget.value for widget in self.action_widgets}

        processed_data = function(selected_layer.data, **kwargs)

        filter_name = self._desanitize_name(self.method_dropdown.value)
        history = selected_layer_metadata.get("filter_parameters", [])
        if isinstance(history, dict):
            # Mask layers store a single {mask: parameters} dict.
            history = [history]
        # Build a new list so the source layer's history is never mutated.
        filter_parameters = [*history, {filter_name: kwargs.copy()}]

        new_layer_name = f"{selected_layer.name} ({self.method_dropdown.value})"
        if new_layer_name in self.viewer.layers:
            output_layer = self.viewer.layers[new_layer_name]
            # Refresh the output layer in place when it came from this filter.
            if output_layer.metadata.get("used_filter", False) == filter_name:
                output_layer.data = processed_data
                output_layer.metadata["filter_parameters"] = filter_parameters
                return

        new_layer = self.viewer.add_image(
            data=processed_data,
            name=new_layer_name,
        )
        metadata = selected_layer_metadata
        metadata["filter_parameters"] = filter_parameters
        metadata["used_filter"] = filter_name
        new_layer.metadata = metadata
327
+
328
+
329
def sphere_mask(
    template: NDArray, center_x: float, center_y: float, center_z: float, radius: float
):
    """
    Build a spherical mask with the shape of ``template``.

    Delegates to ``create_mask`` with ``mask_type="ellipse"`` and a single
    radius. Only ``template.shape`` is used, not its values.
    """
    mask_center = (center_x, center_y, center_z)
    return create_mask(
        mask_type="ellipse",
        shape=template.shape,
        center=mask_center,
        radius=radius,
    )
338
+
339
+
340
def ellipsod_mask(
    template: NDArray,
    center_x: float,
    center_y: float,
    center_z: float,
    radius_x: float,
    radius_y: float,
    radius_z: float,
):
    """
    Build an ellipsoidal mask with the shape of ``template``.

    Delegates to ``create_mask`` with ``mask_type="ellipse"`` and one radius
    per axis. Only ``template.shape`` is used, not its values.
    """
    mask_center = (center_x, center_y, center_z)
    radii = (radius_x, radius_y, radius_z)
    return create_mask(
        mask_type="ellipse", shape=template.shape, center=mask_center, radius=radii
    )
355
+
356
+
357
def box_mask(
    template: NDArray,
    center_x: float,
    center_y: float,
    center_z: float,
    height_x: int,
    height_y: int,
    height_z: int,
):
    """
    Build a box mask with the shape of ``template``.

    Delegates to ``create_mask`` with ``mask_type="box"`` and one height per
    axis. Only ``template.shape`` is used, not its values.
    """
    box_center = (center_x, center_y, center_z)
    box_height = (height_x, height_y, height_z)
    return create_mask(
        mask_type="box", shape=template.shape, center=box_center, height=box_height
    )
372
+
373
+
374
def tube_mask(
    template: NDArray,
    symmetry_axis: int,
    center_x: float,
    center_y: float,
    center_z: float,
    inner_radius: float,
    outer_radius: float,
    height: int,
):
    """
    Build a tube mask with the shape of ``template``.

    Delegates to ``create_mask`` with ``mask_type="tube"``. The center
    coordinates are forwarded as ``base_center``; the remaining parameters
    are forwarded unchanged. Only ``template.shape`` is used, not its values.
    """
    tube_parameters = {
        "symmetry_axis": symmetry_axis,
        "base_center": (center_x, center_y, center_z),
        "inner_radius": inner_radius,
        "outer_radius": outer_radius,
        "height": height,
    }
    return create_mask(mask_type="tube", shape=template.shape, **tube_parameters)
393
+
394
+
395
def wedge_mask(
    template: NDArray,
    tilt_start: float,
    tilt_stop: float,
    gaussian_sigma: float,
    tilt_axis: int = 1,
    omit_negative_frequencies: bool = True,
    extrude_plane: bool = True,
    infinite_plane: bool = True,
):
    """
    Return ``Preprocessor.continuous_wedge_mask`` for ``template.shape``,
    passed through ``np.fft.fftshift``.

    ``tilt_start``/``tilt_stop`` are forwarded as ``start_tilt``/``stop_tilt``
    and ``gaussian_sigma`` as ``sigma``; the other options are forwarded
    unchanged. Only ``template.shape`` is used, not its values.

    NOTE(review): with ``omit_negative_frequencies=True`` the mask presumably
    has a half-spectrum (rfft) layout, in which case shifting the last axis
    as well may not be intended. Confirm.
    """
    mask = preprocessor.continuous_wedge_mask(
        start_tilt=tilt_start,
        stop_tilt=tilt_stop,
        tilt_axis=tilt_axis,
        shape=template.shape,
        sigma=gaussian_sigma,
        omit_negative_frequencies=omit_negative_frequencies,
        extrude_plane=extrude_plane,
        infinite_plane=infinite_plane,
    )
    return np.fft.fftshift(mask)
417
+
418
def threshold_mask(
    template: NDArray,
    standard_deviation: float = 5.0,
    invert: bool = False,
):
    """
    Mask the values lying strictly inside a band around the template mean.

    Parameters
    ----------
    template : NDArray
        Array to threshold.
    standard_deviation : float, optional
        Band half-width in multiples of ``template.std()``, default 5.
    invert : bool, optional
        Return the complement (values on or outside the band), default False.

    Returns
    -------
    NDArray
        Boolean array with the shape of ``template``.
    """
    center = template.mean()
    spread = standard_deviation * template.std()
    inside_band = (template > center - spread) & (template < center + spread)
    return ~inside_band if invert else inside_band
432
+
433
+
434
+
435
class MaskWidget(widgets.Container):
    """
    Dock widget creating masks for the active layer.

    A mask function is chosen from ``self.methods`` and its annotated
    parameters become widgets.

    - "Adapt to layer" pre-fills the parameters from the bounding box of the
      active layer's positive voxels.
    - "Align to axis" rotates the active layer based on its principal axis.

    Created masks are stored as float32 layers named ``"<layer> (<mask>)"``,
    and the share of positive origin-layer density inside the mask is
    reported.
    """

    def __init__(self, viewer):
        super().__init__(layout="vertical")

        self.viewer = viewer
        # Parameter widgets of the currently selected mask type.
        self.action_widgets = []

        self.action_button = widgets.PushButton(text="Create mask", enabled=False)
        self.action_button.changed.connect(self._action)

        # Dropdown label -> mask function.
        self.methods = {
            "Sphere": sphere_mask,
            "Ellipsoid": ellipsod_mask,
            "Tube": tube_mask,
            "Box": box_mask,
            "Wedge": wedge_mask,
            "Threshold": threshold_mask,
        }

        self.method_dropdown = widgets.ComboBox(
            name="Choose Mask", choices=list(self.methods.keys())
        )
        self.method_dropdown.changed.connect(self._on_method_changed)

        self.adapt_button = widgets.PushButton(text="Adapt to layer", enabled=False)
        self.adapt_button.changed.connect(self._update_initial_values)

        self.viewer.layers.selection.events.active.connect(
            self._update_action_button_state
        )

        self.align_button = widgets.PushButton(text="Align to axis", enabled=False)
        self.align_button.changed.connect(self._align_with_axis)
        # Shows the share of positive density inside the latest mask.
        self.density_field = widgets.Label()

        self.append(self.method_dropdown)
        self.append(self.adapt_button)
        self.append(self.align_button)
        self.append(self.action_button)
        self.append(self.density_field)

        # Create GUI for initially selected mask type
        self._on_method_changed(None)
        # Match the buttons to a layer that may already be active.
        self._update_action_button_state(None)

    def _update_action_button_state(self, event):
        """Enable the layer-dependent buttons only while a layer is active."""
        has_active_layer = bool(self.viewer.layers.selection.active)
        self.align_button.enabled = has_active_layer
        self.action_button.enabled = has_active_layer
        self.adapt_button.enabled = has_active_layer

    def _align_with_axis(self):
        """
        Rotate the active layer based on the principal axis of its positive voxels.

        The principal axis is the covariance eigenvector with the largest
        eigenvalue. Rodrigues' formula gives the rotation matrix mapping it
        onto ``[1, 0, 0]``, and ``Density.rotate_array`` applies it. Values
        below machine epsilon are zeroed afterwards. Layers already flagged
        ``is_aligned`` are left untouched. Assumes 3D data.
        """
        active_layer = self.viewer.layers.selection.active
        if active_layer is None or active_layer.metadata.get("is_aligned", False):
            return

        coords = np.array(np.where(active_layer.data > 0)).T
        if coords.shape[0] < 2:
            # A covariance matrix needs at least two positive voxels.
            return
        centered_coords = coords - np.mean(coords, axis=0)
        cov_matrix = np.cov(centered_coords, rowvar=False)

        eigenvalues, eigenvectors = np.linalg.eigh(cov_matrix)
        principal_eigenvector = eigenvectors[:, np.argmax(eigenvalues)]

        target_axis = np.array([1.0, 0.0, 0.0])
        rotation_axis = np.cross(principal_eigenvector, target_axis)
        axis_norm = np.linalg.norm(rotation_axis)
        if np.isclose(axis_norm, 0):
            # Principal axis is already (anti-)parallel to the target axis;
            # normalizing the vanishing cross product would produce NaNs.
            active_layer.metadata["is_aligned"] = True
            return

        # Clip guards arccos against rounding slightly outside [-1, 1].
        rotation_angle = np.arccos(
            np.clip(np.dot(principal_eigenvector, target_axis), -1.0, 1.0)
        )
        k = rotation_axis / axis_norm
        K = np.array([[0, -k[2], k[1]], [k[2], 0, -k[0]], [-k[1], k[0], 0]])
        rotation_matrix = np.eye(3)
        rotation_matrix += np.sin(rotation_angle) * K
        rotation_matrix += (1 - np.cos(rotation_angle)) * np.dot(K, K)

        rotated_data = Density.rotate_array(
            arr=active_layer.data,
            rotation_matrix=rotation_matrix,
            use_geometric_center=False,
        )
        eps = np.finfo(rotated_data.dtype).eps
        rotated_data[rotated_data < eps] = 0

        active_layer.metadata["is_aligned"] = True
        active_layer.data = rotated_data

    def _update_initial_values(self, event=None):
        """
        Fill mask parameters from the bounding box of the active layer.

        The bounding box spans all voxels with positive values. The parameters
        are set as follows:

        - ``center_*``: the box center
        - ``radius_*``: half the box extent per axis
        - ``height_*``: the full box extent per axis
        - ``radius`` and ``inner_radius``: the smallest half-extent
        - ``outer_radius``: the largest half-extent
        - ``height``: equal to ``radius``

        Widgets whose names do not match are left unchanged.
        """
        active_layer = self.viewer.layers.selection.active
        if active_layer is None:
            return

        coordinates = np.array(np.where(active_layer.data > 0))
        if coordinates.size == 0:
            # No positive voxels: min/max over an empty array would raise.
            return

        coordinates_min = coordinates.min(axis=1)
        coordinates_max = coordinates.max(axis=1)
        coordinates_heights = coordinates_max - coordinates_min
        coordinate_radius = np.divide(coordinates_heights, 2)
        box_center = coordinate_radius + coordinates_min

        defaults = dict(zip(["center_x", "center_y", "center_z"], box_center))
        defaults.update(
            dict(zip(["radius_x", "radius_y", "radius_z"], coordinate_radius))
        )
        defaults.update(
            dict(zip(["height_x", "height_y", "height_z"], coordinates_heights))
        )

        defaults["radius"] = np.min(coordinate_radius)
        defaults["inner_radius"] = np.min(coordinate_radius)
        defaults["outer_radius"] = np.max(coordinate_radius)
        defaults["height"] = defaults["radius"]

        for widget in self.action_widgets:
            if widget.name in defaults:
                widget.value = defaults[widget.name]

    def _on_method_changed(self, event=None):
        """Rebuild the parameter widgets for the selected mask type."""
        for widget in self.action_widgets:
            self.remove(widget)
        self.action_widgets.clear()

        function = self.methods.get(self.method_dropdown.value)
        if function is None:
            return

        parameter_widgets = widgets_from_function(function)
        # Insert directly after the method dropdown in signature order;
        # inserting every widget at index 1 would reverse that order.
        for position, widget in enumerate(parameter_widgets, start=1):
            self.action_widgets.append(widget)
            self.insert(position, widget)

    def _action(self):
        """
        Create a mask for the active layer, or refresh an existing one.

        If the layer named ``"<active> (<mask>)"`` exists, or otherwise the
        active layer itself, was created with the same mask type, its data and
        recorded parameters are replaced. Otherwise a new float32 layer is
        added. Its metadata copies the source metadata and adds
        ``filter_parameters``, ``mask`` and ``origin_layer``.
        """
        method_name = self.method_dropdown.value
        function = self.methods.get(method_name)
        selected_layer = self.viewer.layers.selection.active
        if function is None or selected_layer is None:
            return

        kwargs = {widget.name: widget.value for widget in self.action_widgets}
        processed_data = function(template=selected_layer.data, **kwargs)
        new_layer_name = f"{selected_layer.name} ({method_name})"

        if new_layer_name in self.viewer.layers:
            selected_layer = self.viewer.layers[new_layer_name]

        processed_data = processed_data.astype(np.float32)
        metadata = selected_layer.metadata
        if metadata.get("mask", False) == method_name:
            selected_layer.data = processed_data
            # Keep the recorded parameters in sync with the refreshed data.
            metadata["filter_parameters"] = {method_name: kwargs.copy()}
        else:
            new_layer = self.viewer.add_image(
                data=processed_data,
                name=new_layer_name,
            )
            metadata = selected_layer.metadata.copy()
            metadata["filter_parameters"] = {method_name: kwargs.copy()}
            metadata["mask"] = method_name
            metadata["origin_layer"] = selected_layer.name
            new_layer.metadata = metadata

        self._update_density_field(metadata.get("origin_layer"), processed_data)

    def _update_density_field(self, origin_layer_name, mask):
        """
        Show the percentage of positive origin-layer density inside ``mask``.

        Nothing is updated when the origin layer is missing, its shape differs
        from the mask, or it contains no positive density.
        """
        if origin_layer_name is None or origin_layer_name not in self.viewer.layers:
            return
        origin_data = self.viewer.layers[origin_layer_name].data
        # Compare shapes directly; np.allclose on shape tuples of different
        # length raises instead of returning False.
        if tuple(origin_data.shape) != tuple(mask.shape):
            return
        total_density = np.sum(np.fmax(origin_data, 0))
        if total_density == 0:
            return
        in_mask = np.sum(np.fmax(origin_data * mask, 0))
        in_mask /= total_density
        in_mask *= 100
        self.density_field.value = f"Positive Density in Mask: {in_mask:.2f}%"
590
+
591
+
592
class ExportWidget(widgets.Container):
    """
    Dock widget writing the active image layer to disk.

    The export button opens a save dialog filtered to ``*.mrc``. Before
    ``Layer.save`` is called, the gzip checkbox value is stored in the
    layer's ``write_gzip`` metadata entry.
    """

    def __init__(self, viewer):
        super().__init__(layout="vertical")

        self.viewer = viewer
        self.selected_filename = ""

        button_row = widgets.Container(layout="horizontal")

        self.gzip_output = widgets.CheckBox(name="gzip", value=False, label="gzip")
        self.export_button = widgets.PushButton(name="Export", text="Export")
        self.export_button.clicked.connect(self._get_save_path)

        for child in (self.export_button, self.gzip_output):
            button_row.append(child)
        self.append(button_row)

        # Enable only if a layer is already active, then track selection.
        self._update_export_button_state(None)
        self.viewer.layers.selection.events.active.connect(
            self._update_export_button_state
        )

    def _get_save_path(self, event):
        """Ask the user for an output path and export when one is chosen."""
        dialog_options = QFileDialog.Options()
        chosen_path, _ = QFileDialog.getSaveFileName(
            self.native,
            "Save As...",
            "",
            "MRC Files (*.mrc)",
            options=dialog_options,
        )
        if not chosen_path:
            return
        self.selected_filename = chosen_path
        self._export_data()

    def _update_export_button_state(self, event):
        """Update the enabled state of the export button based on the active layer."""
        self.export_button.enabled = bool(self.viewer.layers.selection.active)

    def _export_data(self):
        """Save the active layer to ``self.selected_filename`` if it is an Image."""
        layer = self.viewer.layers.selection.active
        if not layer or not isinstance(layer, Image):
            return
        layer.metadata["write_gzip"] = self.gzip_output.value
        layer.save(path=self.selected_filename)
637
+
638
+
639
class PointCloudWidget(widgets.Container):
    """
    Dock widget importing and exporting point clouds as tab-separated files.

    Imported files must provide ``z``, ``y`` and ``x`` columns. The full
    table is kept per layer so that, on export, its remaining columns are
    joined back onto points whose coordinates still match. Points layers
    without an imported table are exported with coordinates only.
    """

    def __init__(self, viewer):
        super().__init__(layout="vertical")

        self.viewer = viewer
        # Layer name -> table the layer was imported from.
        self.dataframes = {}

        self.import_button = widgets.PushButton(
            name="Import", text="Import Point Cloud"
        )
        self.import_button.clicked.connect(self._get_load_path)

        self.export_button = widgets.PushButton(
            name="Export", text="Export Point Cloud"
        )
        self.export_button.clicked.connect(self._export_point_cloud)
        self.export_button.enabled = False

        self.append(self.import_button)
        self.append(self.export_button)
        self.viewer.layers.selection.events.changed.connect(self._update_buttons)

    def _update_buttons(self, event):
        """Enable export only while the active layer is a Points layer."""
        active_layer = self.viewer.layers.selection.active
        self.export_button.enabled = bool(active_layer) and isinstance(
            active_layer, napari.layers.Points
        )

    def _export_point_cloud(self, event):
        """Ask for a target path and write the active Points layer as TSV."""
        options = QFileDialog.Options()
        filename, _ = QFileDialog.getSaveFileName(
            self.native,
            "Save Point Cloud File...",
            "",
            "TSV Files (*.tsv);;All Files (*)",
            options=options,
        )
        if not filename:
            return

        layer = self.viewer.layers.selection.active
        if not layer or not isinstance(layer, napari.layers.Points):
            return

        export_data = pd.DataFrame(layer.data, columns=["z", "y", "x"])
        original_dataframe = self.dataframes.get(layer.name)
        if original_dataframe is not None:
            # Re-attach the imported columns to points with unchanged
            # coordinates. Layers that were not imported have no table;
            # merging with a column-less DataFrame would raise KeyError.
            export_data = pd.merge(
                export_data, original_dataframe, on=["z", "y", "x"], how="left"
            )
        export_data.to_csv(filename, sep="\t", index=False)

    def _get_load_path(self, event):
        """Ask for a TSV file and load it as a point cloud."""
        options = QFileDialog.Options()
        filename, _ = QFileDialog.getOpenFileName(
            self.native,
            "Open Point Cloud File...",
            "",
            "TSV Files (*.tsv);;All Files (*)",
            options=options,
        )
        if filename:
            self._load_point_cloud(filename)

    def _load_point_cloud(self, filename):
        """Read a TSV point cloud and add its z/y/x columns as a Points layer."""
        import os

        dataframe = pd.read_csv(filename, sep="\t")
        points = dataframe[["z", "y", "x"]].values
        layer = self.viewer.add_points(
            points, size=10, name=os.path.basename(filename)
        )
        # Key by the name the layer actually received. It may differ from the
        # requested one if napari de-duplicates layer names.
        self.dataframes[layer.name] = dataframe
709
+
710
+
711
def main():
    """
    Launch the preprocessing GUI.

    Opens a napari viewer and docks the filter, mask and export widgets on
    the right and the point-cloud widget on the left. Then starts the napari
    event loop.
    """
    viewer = napari.Viewer()

    # Widgets are constructed in this order, then docked in dock_layout order.
    filter_dock = FilterWidget(preprocessor, viewer)
    mask_dock = MaskWidget(viewer)
    export_dock = ExportWidget(viewer)
    point_cloud_dock = PointCloudWidget(viewer)

    dock_layout = (
        (filter_dock, "Preprocess", "right"),
        (mask_dock, "Mask", "right"),
        (point_cloud_dock, "PointCloud", "left"),
        (export_dock, "Export", "right"),
    )
    for dock_widget, dock_name, dock_area in dock_layout:
        viewer.window.add_dock_widget(
            widget=dock_widget, name=dock_name, area=dock_area
        )

    napari.run()


if __name__ == "__main__":
    main()