celldetective 1.4.2__py3-none-any.whl → 1.5.0b1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (152) hide show
  1. celldetective/__init__.py +25 -0
  2. celldetective/__main__.py +62 -43
  3. celldetective/_version.py +1 -1
  4. celldetective/extra_properties.py +477 -399
  5. celldetective/filters.py +192 -97
  6. celldetective/gui/InitWindow.py +541 -411
  7. celldetective/gui/__init__.py +0 -15
  8. celldetective/gui/about.py +44 -39
  9. celldetective/gui/analyze_block.py +120 -84
  10. celldetective/gui/base/__init__.py +0 -0
  11. celldetective/gui/base/channel_norm_generator.py +335 -0
  12. celldetective/gui/base/components.py +249 -0
  13. celldetective/gui/base/feature_choice.py +92 -0
  14. celldetective/gui/base/figure_canvas.py +52 -0
  15. celldetective/gui/base/list_widget.py +133 -0
  16. celldetective/gui/{styles.py → base/styles.py} +92 -36
  17. celldetective/gui/base/utils.py +33 -0
  18. celldetective/gui/base_annotator.py +900 -767
  19. celldetective/gui/classifier_widget.py +6 -22
  20. celldetective/gui/configure_new_exp.py +777 -671
  21. celldetective/gui/control_panel.py +635 -524
  22. celldetective/gui/dynamic_progress.py +449 -0
  23. celldetective/gui/event_annotator.py +2023 -1662
  24. celldetective/gui/generic_signal_plot.py +1292 -944
  25. celldetective/gui/gui_utils.py +899 -1289
  26. celldetective/gui/interactions_block.py +658 -0
  27. celldetective/gui/interactive_timeseries_viewer.py +447 -0
  28. celldetective/gui/json_readers.py +48 -15
  29. celldetective/gui/layouts/__init__.py +5 -0
  30. celldetective/gui/layouts/background_model_free_layout.py +537 -0
  31. celldetective/gui/layouts/channel_offset_layout.py +134 -0
  32. celldetective/gui/layouts/local_correction_layout.py +91 -0
  33. celldetective/gui/layouts/model_fit_layout.py +372 -0
  34. celldetective/gui/layouts/operation_layout.py +68 -0
  35. celldetective/gui/layouts/protocol_designer_layout.py +96 -0
  36. celldetective/gui/pair_event_annotator.py +3130 -2435
  37. celldetective/gui/plot_measurements.py +586 -267
  38. celldetective/gui/plot_signals_ui.py +724 -506
  39. celldetective/gui/preprocessing_block.py +395 -0
  40. celldetective/gui/process_block.py +1678 -1831
  41. celldetective/gui/seg_model_loader.py +580 -473
  42. celldetective/gui/settings/__init__.py +0 -7
  43. celldetective/gui/settings/_cellpose_model_params.py +181 -0
  44. celldetective/gui/settings/_event_detection_model_params.py +95 -0
  45. celldetective/gui/settings/_segmentation_model_params.py +159 -0
  46. celldetective/gui/settings/_settings_base.py +77 -65
  47. celldetective/gui/settings/_settings_event_model_training.py +752 -526
  48. celldetective/gui/settings/_settings_measurements.py +1133 -964
  49. celldetective/gui/settings/_settings_neighborhood.py +574 -488
  50. celldetective/gui/settings/_settings_segmentation_model_training.py +779 -564
  51. celldetective/gui/settings/_settings_signal_annotator.py +329 -305
  52. celldetective/gui/settings/_settings_tracking.py +1304 -1094
  53. celldetective/gui/settings/_stardist_model_params.py +98 -0
  54. celldetective/gui/survival_ui.py +422 -312
  55. celldetective/gui/tableUI.py +1665 -1701
  56. celldetective/gui/table_ops/_maths.py +295 -0
  57. celldetective/gui/table_ops/_merge_groups.py +140 -0
  58. celldetective/gui/table_ops/_merge_one_hot.py +95 -0
  59. celldetective/gui/table_ops/_query_table.py +43 -0
  60. celldetective/gui/table_ops/_rename_col.py +44 -0
  61. celldetective/gui/thresholds_gui.py +382 -179
  62. celldetective/gui/viewers/__init__.py +0 -0
  63. celldetective/gui/viewers/base_viewer.py +700 -0
  64. celldetective/gui/viewers/channel_offset_viewer.py +331 -0
  65. celldetective/gui/viewers/contour_viewer.py +394 -0
  66. celldetective/gui/viewers/size_viewer.py +153 -0
  67. celldetective/gui/viewers/spot_detection_viewer.py +341 -0
  68. celldetective/gui/viewers/threshold_viewer.py +309 -0
  69. celldetective/gui/workers.py +403 -126
  70. celldetective/log_manager.py +92 -0
  71. celldetective/measure.py +1895 -1478
  72. celldetective/napari/__init__.py +0 -0
  73. celldetective/napari/utils.py +1025 -0
  74. celldetective/neighborhood.py +1914 -1448
  75. celldetective/preprocessing.py +1620 -1220
  76. celldetective/processes/__init__.py +0 -0
  77. celldetective/processes/background_correction.py +271 -0
  78. celldetective/processes/compute_neighborhood.py +894 -0
  79. celldetective/processes/detect_events.py +246 -0
  80. celldetective/processes/downloader.py +137 -0
  81. celldetective/processes/measure_cells.py +565 -0
  82. celldetective/processes/segment_cells.py +760 -0
  83. celldetective/processes/track_cells.py +435 -0
  84. celldetective/processes/train_segmentation_model.py +694 -0
  85. celldetective/processes/train_signal_model.py +265 -0
  86. celldetective/processes/unified_process.py +292 -0
  87. celldetective/regionprops/_regionprops.py +358 -317
  88. celldetective/relative_measurements.py +987 -710
  89. celldetective/scripts/measure_cells.py +313 -212
  90. celldetective/scripts/measure_relative.py +90 -46
  91. celldetective/scripts/segment_cells.py +165 -104
  92. celldetective/scripts/segment_cells_thresholds.py +96 -68
  93. celldetective/scripts/track_cells.py +198 -149
  94. celldetective/scripts/train_segmentation_model.py +324 -201
  95. celldetective/scripts/train_signal_model.py +87 -45
  96. celldetective/segmentation.py +844 -749
  97. celldetective/signals.py +3514 -2861
  98. celldetective/tracking.py +30 -15
  99. celldetective/utils/__init__.py +0 -0
  100. celldetective/utils/cellpose_utils/__init__.py +133 -0
  101. celldetective/utils/color_mappings.py +42 -0
  102. celldetective/utils/data_cleaning.py +630 -0
  103. celldetective/utils/data_loaders.py +450 -0
  104. celldetective/utils/dataset_helpers.py +207 -0
  105. celldetective/utils/downloaders.py +235 -0
  106. celldetective/utils/event_detection/__init__.py +8 -0
  107. celldetective/utils/experiment.py +1782 -0
  108. celldetective/utils/image_augmenters.py +308 -0
  109. celldetective/utils/image_cleaning.py +74 -0
  110. celldetective/utils/image_loaders.py +926 -0
  111. celldetective/utils/image_transforms.py +335 -0
  112. celldetective/utils/io.py +62 -0
  113. celldetective/utils/mask_cleaning.py +348 -0
  114. celldetective/utils/mask_transforms.py +5 -0
  115. celldetective/utils/masks.py +184 -0
  116. celldetective/utils/maths.py +351 -0
  117. celldetective/utils/model_getters.py +325 -0
  118. celldetective/utils/model_loaders.py +296 -0
  119. celldetective/utils/normalization.py +380 -0
  120. celldetective/utils/parsing.py +465 -0
  121. celldetective/utils/plots/__init__.py +0 -0
  122. celldetective/utils/plots/regression.py +53 -0
  123. celldetective/utils/resources.py +34 -0
  124. celldetective/utils/stardist_utils/__init__.py +104 -0
  125. celldetective/utils/stats.py +90 -0
  126. celldetective/utils/types.py +21 -0
  127. {celldetective-1.4.2.dist-info → celldetective-1.5.0b1.dist-info}/METADATA +1 -1
  128. celldetective-1.5.0b1.dist-info/RECORD +187 -0
  129. {celldetective-1.4.2.dist-info → celldetective-1.5.0b1.dist-info}/WHEEL +1 -1
  130. tests/gui/test_new_project.py +129 -117
  131. tests/gui/test_project.py +127 -79
  132. tests/test_filters.py +39 -15
  133. tests/test_notebooks.py +8 -0
  134. tests/test_tracking.py +232 -13
  135. tests/test_utils.py +123 -77
  136. celldetective/gui/base_components.py +0 -23
  137. celldetective/gui/layouts.py +0 -1602
  138. celldetective/gui/processes/compute_neighborhood.py +0 -594
  139. celldetective/gui/processes/downloader.py +0 -111
  140. celldetective/gui/processes/measure_cells.py +0 -360
  141. celldetective/gui/processes/segment_cells.py +0 -499
  142. celldetective/gui/processes/track_cells.py +0 -303
  143. celldetective/gui/processes/train_segmentation_model.py +0 -270
  144. celldetective/gui/processes/train_signal_model.py +0 -108
  145. celldetective/gui/table_ops/merge_groups.py +0 -118
  146. celldetective/gui/viewers.py +0 -1354
  147. celldetective/io.py +0 -3663
  148. celldetective/utils.py +0 -3108
  149. celldetective-1.4.2.dist-info/RECORD +0 -123
  150. {celldetective-1.4.2.dist-info → celldetective-1.5.0b1.dist-info}/entry_points.txt +0 -0
  151. {celldetective-1.4.2.dist-info → celldetective-1.5.0b1.dist-info}/licenses/LICENSE +0 -0
  152. {celldetective-1.4.2.dist-info → celldetective-1.5.0b1.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,1025 @@
1
+ import gc
2
+ import json
3
+ import os
4
+ from pathlib import Path, PurePath
5
+
6
+ import napari
7
+ import numpy as np
8
+ import pandas as pd
9
+ from PyQt5.QtWidgets import QMessageBox, QWidget, QVBoxLayout
10
+ from celldetective.utils.io import save_tiff_imagej_compatible
11
+ from magicgui import magicgui
12
+ from skimage.measure import regionprops_table
13
+ from tifffile import imread
14
+ from tqdm import tqdm
15
+
16
+ from celldetective.utils.data_cleaning import tracks_to_btrack
17
+ from celldetective.utils.mask_cleaning import auto_correct_masks, relabel_segmentation
18
+ from celldetective.utils.image_loaders import locate_labels, locate_stack_and_labels
19
+ from celldetective.utils.data_loaders import get_position_table, load_tracking_data
20
+ from celldetective.utils.experiment import (
21
+ extract_experiment_from_position,
22
+ _get_contrast_limits,
23
+ get_experiment_wells,
24
+ get_experiment_labels,
25
+ get_experiment_metadata,
26
+ extract_experiment_channels,
27
+ )
28
+ from celldetective.utils.parsing import config_section_to_dict
29
+ from celldetective import get_logger
30
+
31
+ logger = get_logger()
32
+
33
+
34
def control_tracks(
    position,
    prefix="Aligned",
    population="target",
    relabel=True,
    flush_memory=True,
    threads=1,
    progress_callback=None,
    prepare_only=False,
):
    """
    Load the stack, labels and trajectories of a position and open them in napari for track control.

    Parameters
    ----------
    position : str
        Path to the position directory. Backslashes are converted to forward
        slashes and a trailing "/" is appended when missing.
    prefix : str, optional, default="Aligned"
        Prefix of the movie file, forwarded to ``locate_stack_and_labels``.
    population : str, optional, default="target"
        Population whose labels and trajectory table are loaded
        (e.g. "target" or "effectors").
    relabel : bool, optional, default=True
        If True, the mask values are replaced with the track IDs before display.
    flush_memory : bool, optional, default=True
        If True, the napari layers are popped and the arrays released once the
        viewer is closed.
    threads : int, optional, default=1
        Number of threads used when relabelling the masks.
    progress_callback : callable, optional
        Forwarded to ``relabel_segmentation`` to report progress.
    prepare_only : bool, optional, default=False
        If True, do not open napari; return the prepared data dictionary instead.

    Returns
    -------
    dict or None
        The value returned by ``view_tracks_in_napari``: the prepared data
        dictionary when ``prepare_only`` is True, otherwise None. None is also
        returned when no trajectory table exists for the position or when the
        relabelling returns no labels.

    Example
    -------
    >>> control_tracks("/path/to/data/position_1", prefix="Aligned", population="target", relabel=True, flush_memory=True, threads=4)

    """

    # Normalise separators first so the trailing-separator check is not fooled
    # by a path that already ends with "/" on Windows (which produced "//").
    position = position.replace("\\", "/")
    if not position.endswith("/"):
        position += "/"

    stack, labels = locate_stack_and_labels(
        position, prefix=prefix, population=population
    )

    return view_tracks_in_napari(
        position,
        population,
        labels=labels,
        stack=stack,
        relabel=relabel,
        flush_memory=flush_memory,
        threads=threads,
        progress_callback=progress_callback,
        prepare_only=prepare_only,
    )
105
+
106
+
107
def tracks_to_napari(df, exclude_nans=False):
    """
    Convert a trajectory table into the arrays used by the napari points and tracks layers.

    Parameters
    ----------
    df : pandas.DataFrame
        Trajectory table, converted through ``tracks_to_btrack``.
    exclude_nans : bool, optional
        Forwarded to ``tracks_to_btrack``.

    Returns
    -------
    tuple
        ``(vertices, tracks, properties, graph)`` where ``vertices`` keeps column 1
        and the last two columns of the btrack array, ``tracks`` is the btrack
        array itself when it has four columns and columns 0, 1, 3 and 4
        otherwise, and ``properties`` / ``graph`` are passed through unchanged.
    """
    btrack_array, track_properties, track_graph = tracks_to_btrack(
        df, exclude_nans=exclude_nans
    )
    # NOTE(review): columns presumably follow (ID, T, [Z,] Y, X) — confirm in tracks_to_btrack.
    point_coords = btrack_array[:, [1, -2, -1]]
    track_array = (
        btrack_array
        if btrack_array.shape[1] == 4
        else btrack_array[:, [0, 1, 3, 4]]
    )
    return point_coords, track_array, track_properties, track_graph
116
+
117
+
118
def view_tracks_in_napari(
    position,
    population,
    stack=None,
    labels=None,
    relabel=True,
    flush_memory=True,
    threads=1,
    progress_callback=None,
    prepare_only=False,
):
    """
    Prepare the trajectories of a position for display and open them in napari.

    The trajectory table of ``population`` is loaded for ``position``. The masks
    are optionally relabelled with the track IDs, and the tracks are converted to
    napari arrays. The result is either returned (``prepare_only=True``) or
    passed to ``launch_napari_viewer``.

    Parameters
    ----------
    position : str
        Path to the position directory.
    population : str
        Population whose trajectory table is loaded.
    stack : numpy.ndarray, optional
        Image stack to display. May be None.
    labels : numpy.ndarray, optional
        Segmentation masks. May be None.
    relabel : bool, optional
        If True and ``labels`` is given, replace the mask values with the track
        IDs using ``relabel_segmentation``.
    flush_memory : bool, optional
        Forwarded to ``launch_napari_viewer``.
    threads : int, optional
        Number of threads used by ``relabel_segmentation``.
    progress_callback : callable, optional
        Forwarded to ``relabel_segmentation``.
    prepare_only : bool, optional
        If True, return the prepared data instead of opening the viewer.

    Returns
    -------
    dict or None
        With ``prepare_only=True``, a dict with keys ``stack``, ``labels``,
        ``vertices``, ``tracks``, ``properties``, ``graph``, ``shared_data``,
        ``contrast_limits`` and ``flush_memory`` (the keyword arguments of
        ``launch_napari_viewer``). Otherwise, the return value of
        ``launch_napari_viewer``. None if no trajectory table is found or if
        ``relabel_segmentation`` returns None.
    """

    df, df_path = get_position_table(position, population=population, return_path=True)
    if df is None:
        print("Please compute trajectories first... Abort...")
        return None

    # State shared with the napari callbacks: track edits replace "df" and the
    # export button writes it back to "path".
    shared_data = {
        "df": df,
        "path": df_path,
        "position": position,
        "population": population,
        "selected_frame": None,
    }

    if labels is not None and relabel:
        print("Replacing the cell mask labels with the track ID...")
        labels = relabel_segmentation(
            labels,
            df,
            exclude_nans=True,
            threads=threads,
            progress_callback=progress_callback,
        )
        if labels is None:
            # NOTE(review): presumably the relabelling was interrupted — confirm in relabel_segmentation.
            return None

    vertices, tracks, properties, graph = tracks_to_napari(df, exclude_nans=True)

    contrast_limits = _get_contrast_limits(stack)

    data = {
        "stack": stack,
        "labels": labels,
        "vertices": vertices,
        "tracks": tracks,
        "properties": properties,
        "graph": graph,
        "shared_data": shared_data,
        "contrast_limits": contrast_limits,
        "flush_memory": flush_memory,
    }

    if prepare_only:
        return data

    return launch_napari_viewer(**data)
177
+
178
+
179
def _measure_track_in_frame(position, population, current_labels, frame, track_id):
    """
    Measure the cell(s) carrying ``track_id`` at ``frame`` and return them as new table rows.

    The centroid is measured on the original mask of ``frame`` (reloaded with
    ``locate_labels``), restricted to the pixels where the edited mask
    ``current_labels`` equals ``track_id``. The returned DataFrame has the
    columns POSITION_X (centroid-1), POSITION_Y (centroid-0), class_id (the
    original mask label), FRAME and TRACK_ID.
    """
    original_labels = locate_labels(position, population=population, frames=frame)
    original_labels[current_labels != track_id] = 0
    props = pd.DataFrame(
        regionprops_table(
            original_labels,
            intensity_image=None,
            properties=["centroid", "label"],
        )
    )
    new_cell = props[["centroid-1", "centroid-0", "label"]].copy()
    new_cell.rename(
        columns={
            "centroid-1": "POSITION_X",
            "centroid-0": "POSITION_Y",
            "label": "class_id",
        },
        inplace=True,
    )
    new_cell["FRAME"] = frame
    new_cell["TRACK_ID"] = track_id
    return new_cell


def launch_napari_viewer(
    stack,
    labels,
    vertices,
    tracks,
    properties,
    graph,
    shared_data,
    contrast_limits,
    flush_memory=True,
):
    """
    Open a blocking napari viewer to inspect and correct cell trajectories.

    The viewer shows the image stack (one grayscale layer per channel), the
    segmentation labels, the track vertices as points and the tracks.
    Double-clicking a cell in the segmentation layer offers to propagate the
    currently selected label (track ID) to the clicked cell, from the clicked
    frame onwards. The previous owner of that ID receives a fresh ID
    (max label + 1), and ``shared_data["df"]`` is updated accordingly.

    A dock widget exports the edited table to ``shared_data["path"]`` after:
    - recomputing velocities,
    - writing the first-detection class,
    - applying the optional post-processing from
      ``configs/tracking_instructions_{population}.json``,
    - dropping "Unnamed" columns.

    Parameters
    ----------
    stack : numpy.ndarray or None
        Image stack, displayed with ``channel_axis=-1`` (channels last). Skipped if None.
    labels : numpy.ndarray or None
        Segmentation masks, one per frame along the first axis. When None, the
        segmentation layer, its editing callbacks and the export are disabled.
    vertices, tracks, properties, graph :
        Output of ``tracks_to_napari`` for the trajectory table.
    shared_data : dict
        Mutable state shared with the callbacks. Keys read: ``df``, ``path``,
        ``position``, ``population``. ``df`` and ``selected_frame`` are
        overwritten by the callbacks.
    contrast_limits :
        Passed to ``viewer.add_image``.
    flush_memory : bool, optional
        If True, pop all layers and run the garbage collector once the viewer
        is closed.

    Returns
    -------
    None
        The call blocks until the viewer window is closed.
    """

    viewer = napari.Viewer()
    if stack is not None:
        viewer.add_image(
            stack,
            channel_axis=-1,
            colormap=["gray"] * stack.shape[-1],
            contrast_limits=contrast_limits,
        )
    labels_layer = None
    if labels is not None:
        labels_layer = viewer.add_labels(
            labels.astype(int), name="segmentation", opacity=0.4
        )
    viewer.add_points(vertices, size=4, name="points", opacity=0.3)
    viewer.add_tracks(tracks, properties=properties, graph=graph, name="tracks")

    # Shape of a single mask, needed by write_first_detection_class on export.
    labels_shape = labels[0].shape if labels is not None else None

    def lock_controls(layer, widgets=(), locked=True):
        """Best-effort (un)locking of buttons in a layer's control panel."""
        try:
            # Public Window.qt_viewer is deprecated in recent napari; prefer the
            # private attribute when it exists, fall back for older versions.
            qt_viewer = getattr(viewer.window, "_qt_viewer", None)
            if qt_viewer is None:
                qt_viewer = viewer.window.qt_viewer
            qctrl = qt_viewer.controls.widgets[layer]
        except (AttributeError, KeyError):
            print(f"Could not access the controls of layer {layer.name}...")
            return
        for wdg in widgets:
            try:
                getattr(qctrl, wdg).setEnabled(not locked)
            except AttributeError:
                # Button names differ between napari versions: skip missing ones.
                pass

    if labels_layer is not None:
        label_widget_list = [
            "paint_button",
            "erase_button",
            "fill_button",
            "polygon_button",
            "transform_button",
        ]
        lock_controls(labels_layer, label_widget_list)

    point_widget_list = [
        "addition_button",
        "delete_button",
        "select_button",
        "transform_button",
    ]
    lock_controls(viewer.layers["points"], point_widget_list)

    track_widget_list = ["transform_button"]
    lock_controls(viewer.layers["tracks"], track_widget_list)

    # Initialize selected frame
    shared_data["selected_frame"] = viewer.dims.current_step[0]

    def export_modifications():

        from celldetective.tracking import (
            write_first_detection_class,
            clean_trajectories,
        )
        from celldetective.utils.maths import velocity_per_track

        if labels_shape is None:
            print("No segmentation labels loaded: cannot export the tracks... Abort...")
            return None

        df = shared_data["df"]
        position = shared_data["position"]
        population = shared_data["population"]
        df = velocity_per_track(df, window_size=3, mode="bi")
        df = write_first_detection_class(df, img_shape=labels_shape)

        experiment = extract_experiment_from_position(position)
        instruction_file = "/".join(
            [experiment, "configs", f"tracking_instructions_{population}.json"]
        )
        print(f"{instruction_file=}")
        if os.path.exists(instruction_file):
            print("Tracking configuration file found...")
            with open(instruction_file, "r") as f:
                instructions = json.load(f)
            if "post_processing_options" in instructions:
                post_processing_options = instructions["post_processing_options"]
                print(
                    f"Applying the following track postprocessing: {post_processing_options}..."
                )
                df = clean_trajectories(df.copy(), **post_processing_options)
        unnamed_cols = [
            c for c in df.columns if isinstance(c, str) and c.startswith("Unnamed")
        ]
        df = df.drop(unnamed_cols, axis=1)
        print(f"{list(df.columns)=}")
        df.to_csv(shared_data["path"], index=False)
        print("Done...")

    @magicgui(call_button="Export the modified\ntracks...")
    def export_table_widget():
        return export_modifications()

    from celldetective.gui.base.styles import Styles

    export_table_widget.native.setStyleSheet(Styles().button_style_sheet)

    if labels_layer is not None:

        def label_changed(event):
            # Remember the frame in which a non-background label was picked.
            if labels_layer.selected_label != 0:
                shared_data["selected_frame"] = viewer.dims.current_step[0]

        labels_layer.events.selected_label.connect(label_changed)

    viewer.window.add_dock_widget(export_table_widget, area="right")

    if labels_layer is not None:

        @labels_layer.mouse_double_click_callbacks.append
        def on_second_click_of_double_click(layer, event):

            df = shared_data["df"]
            position = shared_data["position"]
            population = shared_data["population"]
            seg = labels_layer.data

            frame, x, y = event.position
            try:
                value_under = seg[int(frame), int(x), int(y)]
            except (IndexError, ValueError, OverflowError):
                print("Invalid mask value...")
                return None
            if value_under == 0:
                return None

            target_track_id = labels_layer.selected_label

            msgBox = QMessageBox()
            msgBox.setIcon(QMessageBox.Question)
            msgBox.setText(
                f"Do you want to propagate track {target_track_id} to the cell under the mouse, track {value_under}?"
            )
            msgBox.setWindowTitle("Info")
            msgBox.setStandardButtons(QMessageBox.Yes | QMessageBox.No)
            if msgBox.exec() == QMessageBox.No:
                return None

            selected_frame = shared_data["selected_frame"]
            if target_track_id not in df[
                "TRACK_ID"
            ].unique() and target_track_id in np.unique(seg[selected_frame]):
                # the selected cell in frame -1 is not in the table... we can add it to DataFrame
                new_cell = _measure_track_in_frame(
                    position,
                    population,
                    seg[selected_frame],
                    selected_frame,
                    target_track_id,
                )
                df = pd.concat([df, new_cell], ignore_index=True)

            if value_under not in df["TRACK_ID"].unique():
                # the cell to add is not currently part of DataFrame, need to add measurement
                new_cell = _measure_track_in_frame(
                    position,
                    population,
                    seg[int(frame)],
                    int(frame),
                    value_under,
                )
                df = pd.concat([df, new_cell], ignore_index=True)

            # From the clicked frame onwards, the current owner of target_track_id
            # gets a fresh ID and the clicked cell takes target_track_id.
            new_id = np.amax(seg) + 1
            for f in seg[int(frame) :]:
                if target_track_id != 0:
                    f[f == target_track_id] = new_id
                f[f == value_under] = target_track_id

            if target_track_id != 0:
                df.loc[
                    (df["FRAME"] >= frame) & (df["TRACK_ID"] == target_track_id),
                    "TRACK_ID",
                ] = new_id
            df.loc[
                (df["FRAME"] >= frame) & (df["TRACK_ID"] == value_under), "TRACK_ID"
            ] = target_track_id
            df = df.loc[df["TRACK_ID"] != 0, :]
            df = df.sort_values(by=["TRACK_ID", "FRAME"])

            new_vertices, new_tracks, new_properties, new_graph = tracks_to_napari(
                df, exclude_nans=True
            )

            tracks_layer = viewer.layers["tracks"]
            tracks_layer.data = new_tracks
            tracks_layer.properties = new_properties
            tracks_layer.graph = new_graph

            points_layer = viewer.layers["points"]
            points_layer.data = new_vertices

            labels_layer.refresh()
            tracks_layer.refresh()
            points_layer.refresh()

            shared_data["df"] = df

    viewer.show(block=True)

    if flush_memory:

        # temporary fix for slight napari memory leak
        try:
            for _ in range(len(viewer.layers)):
                viewer.layers.pop()
        except Exception:
            # best effort: the viewer may already be torn down after closing
            pass

        del viewer
        del stack
        del labels
        gc.collect()
444
+
445
+
446
def load_napari_data(
    position, prefix="Aligned", population="target", return_stack=True
):
    """
    Load the necessary data for visualization in napari.

    Parameters
    ----------
    position : str
        The path to the position directory. Backslashes are converted to forward
        slashes and a trailing "/" is appended when missing.
    prefix : str, optional
        The prefix used to identify the movie file. The default is "Aligned".
    population : str, optional
        The population to load. "target"/"targets" and "effector"/"effectors"
        (case-insensitive) read ``napari_target_trajectories.npy`` and
        ``napari_effector_trajectories.npy``. Any other name ``pop`` reads
        ``napari_{pop}_trajectories.npy``. The default is "target".
    return_stack : bool, optional
        If True (default), load the image stack together with the labels.
        Otherwise, only the labels are loaded and ``stack`` is None.

    Returns
    -------
    tuple
        ``(data, properties, graph, labels, stack)``. ``data``, ``properties``
        and ``graph`` are None when no saved trajectory file exists under
        ``output/tables``.

    Notes
    -----
    The ``.npy`` file is read with ``allow_pickle=True``. Unpickling can execute
    arbitrary code, so only load files from trusted experiment folders.

    Examples
    --------
    >>> data, properties, graph, labels, stack = load_napari_data("path/to/position")
    # Load the necessary data for visualization of target trajectories.

    """

    # Normalise separators before the trailing-separator check (see control_tracks).
    position = position.replace("\\", "/")
    if not position.endswith("/"):
        position += "/"

    population_key = population.lower()
    if population_key in ("target", "targets"):
        file_token = "target"
    elif population_key in ("effector", "effectors"):
        file_token = "effector"
    else:
        file_token = population

    trajectory_file = os.path.join(
        position, "output", "tables", f"napari_{file_token}_trajectories.npy"
    )

    data = properties = graph = None
    if os.path.exists(trajectory_file):
        # The file stores a pickled dict inside a 0-d object array.
        napari_content = np.load(trajectory_file, allow_pickle=True).item()
        data = napari_content["data"]
        properties = napari_content["properties"]
        graph = napari_content["graph"]

    if return_stack:
        stack, labels = locate_stack_and_labels(
            position, prefix=prefix, population=population
        )
    else:
        labels = locate_labels(position, population=population)
        stack = None
    return data, properties, graph, labels, stack
532
+
533
+
534
+ def control_segmentation_napari(
535
+ position, prefix="Aligned", population="target", flush_memory=False
536
+ ):
537
+ """
538
+
539
+ Control the visualization of segmentation labels using the napari viewer.
540
+
541
+ Parameters
542
+ ----------
543
+ position : str
544
+ The position or directory path where the segmentation labels and stack are located.
545
+ prefix : str, optional
546
+ The prefix used to identify the stack. The default is 'Aligned'.
547
+ population : str, optional
548
+ The population type for which the segmentation is performed. The default is 'target'.
549
+ flush_memory : bool, optional
550
+ Pop napari layers upon closing the viewer to empty the memory footprint. The default is `False`.
551
+
552
+ Notes
553
+ -----
554
+ This function loads the segmentation labels and stack corresponding to the specified position and population.
555
+ It then creates a napari viewer and adds the stack and labels as layers for visualization.
556
+
557
+ Examples
558
+ --------
559
+ >>> control_segmentation_napari(position, prefix='Aligned', population="target")
560
+ # Control the visualization of segmentation labels using the napari viewer.
561
+
562
+ """
563
+
564
+ def export_labels():
565
+ labels_layer = viewer.layers["segmentation"].data
566
+ if not os.path.exists(output_folder):
567
+ os.mkdir(output_folder)
568
+
569
+ for t, im in enumerate(tqdm(labels_layer)):
570
+
571
+ try:
572
+ im = auto_correct_masks(im)
573
+ except Exception as e:
574
+ print(e)
575
+
576
+ save_tiff_imagej_compatible(
577
+ output_folder + f"{str(t).zfill(4)}.tif", im.astype(np.int16), axes="YX"
578
+ )
579
+ print("The labels have been successfully rewritten.")
580
+
581
+ def export_annotation():
582
+
583
+ # Locate experiment config
584
+ parent1 = Path(position).parent
585
+ expfolder = parent1.parent
586
+ config = PurePath(expfolder, Path("config.ini"))
587
+ expfolder = str(expfolder)
588
+ exp_name = os.path.split(expfolder)[-1]
589
+
590
+ wells = get_experiment_wells(expfolder)
591
+ well_idx = list(wells).index(str(parent1) + os.sep)
592
+
593
+ label_info = get_experiment_labels(expfolder)
594
+ metadata_info = get_experiment_metadata(expfolder)
595
+
596
+ info = {}
597
+ for k in list(label_info.keys()):
598
+ values = label_info[k]
599
+ try:
600
+ info.update({k: values[well_idx]})
601
+ except Exception as e:
602
+ print(f"{e=}")
603
+
604
+ if metadata_info is not None:
605
+ keys = list(metadata_info.keys())
606
+ for k in keys:
607
+ info.update({k: metadata_info[k]})
608
+
609
+ spatial_calibration = float(
610
+ config_section_to_dict(config, "MovieSettings")["pxtoum"]
611
+ )
612
+ channel_names, channel_indices = extract_experiment_channels(expfolder)
613
+
614
+ annotation_folder = expfolder + os.sep + f"annotations_{population}" + os.sep
615
+ if not os.path.exists(annotation_folder):
616
+ os.mkdir(annotation_folder)
617
+
618
+ print("Exporting!")
619
+ t = viewer.dims.current_step[0]
620
+ labels_layer = viewer.layers["segmentation"].data[t] # at current time
621
+
622
+ try:
623
+ labels_layer = auto_correct_masks(labels_layer)
624
+ except Exception as e:
625
+ print(e)
626
+
627
+ fov_export = True
628
+
629
+ if "Shapes" in viewer.layers:
630
+ squares = viewer.layers["Shapes"].data
631
+ test_in_frame = np.array(
632
+ [
633
+ squares[i][0, 0] == t and len(squares[i]) == 4
634
+ for i in range(len(squares))
635
+ ]
636
+ )
637
+ squares = np.array(squares)
638
+ squares = squares[test_in_frame]
639
+ nbr_squares = len(squares)
640
+ print(f"Found {nbr_squares} ROIs...")
641
+ if nbr_squares > 0:
642
+ # deactivate field of view mode
643
+ fov_export = False
644
+
645
+ for k, sq in enumerate(squares):
646
+ print(f"ROI: {sq}")
647
+ pad_to_256 = False
648
+
649
+ xmin = int(sq[0, 1])
650
+ xmax = int(sq[2, 1])
651
+ if xmax < xmin:
652
+ xmax, xmin = xmin, xmax
653
+ ymin = int(sq[0, 2])
654
+ ymax = int(sq[1, 2])
655
+ if ymax < ymin:
656
+ ymax, ymin = ymin, ymax
657
+ print(f"{xmin=};{xmax=};{ymin=};{ymax=}")
658
+ frame = viewer.layers["Image"].data[t][xmin:xmax, ymin:ymax]
659
+ if frame.shape[1] < 256 or frame.shape[0] < 256:
660
+ pad_to_256 = True
661
+ print(
662
+ "Crop too small! Padding with zeros to reach 256*256 pixels..."
663
+ )
664
+ # continue
665
+ multichannel = [frame]
666
+ for i in range(len(channel_indices) - 1):
667
+ try:
668
+ frame = viewer.layers[f"Image [{i + 1}]"].data[t][
669
+ xmin:xmax, ymin:ymax
670
+ ]
671
+ multichannel.append(frame)
672
+ except:
673
+ pass
674
+ multichannel = np.array(multichannel)
675
+ lab = labels_layer[xmin:xmax, ymin:ymax].astype(np.int16)
676
+ if pad_to_256:
677
+ shape = multichannel.shape
678
+ pad_length_x = max([0, 256 - multichannel.shape[1]])
679
+ if pad_length_x > 0 and pad_length_x % 2 == 1:
680
+ pad_length_x += 1
681
+ pad_length_y = max([0, 256 - multichannel.shape[2]])
682
+ if pad_length_y > 0 and pad_length_y % 2 == 1:
683
+ pad_length_y += 1
684
+ padded_image = np.array(
685
+ [
686
+ np.pad(
687
+ im,
688
+ (
689
+ (pad_length_x // 2, pad_length_x // 2),
690
+ (pad_length_y // 2, pad_length_y // 2),
691
+ ),
692
+ mode="constant",
693
+ )
694
+ for im in multichannel
695
+ ]
696
+ )
697
+ padded_label = np.pad(
698
+ lab,
699
+ (
700
+ (pad_length_x // 2, pad_length_x // 2),
701
+ (pad_length_y // 2, pad_length_y // 2),
702
+ ),
703
+ mode="constant",
704
+ )
705
+ lab = padded_label
706
+ multichannel = padded_image
707
+
708
+ save_tiff_imagej_compatible(
709
+ annotation_folder
710
+ + f"{exp_name}_{position.split(os.sep)[-2]}_{str(t).zfill(4)}_roi_{xmin}_{xmax}_{ymin}_{ymax}_labelled.tif",
711
+ lab,
712
+ axes="YX",
713
+ )
714
+ save_tiff_imagej_compatible(
715
+ annotation_folder
716
+ + f"{exp_name}_{position.split(os.sep)[-2]}_{str(t).zfill(4)}_roi_{xmin}_{xmax}_{ymin}_{ymax}.tif",
717
+ multichannel,
718
+ axes="CYX",
719
+ )
720
+
721
+ info.update(
722
+ {
723
+ "spatial_calibration": spatial_calibration,
724
+ "channels": list(channel_names),
725
+ "frame": t,
726
+ }
727
+ )
728
+
729
+ info_name = (
730
+ annotation_folder
731
+ + f"{exp_name}_{position.split(os.sep)[-2]}_{str(t).zfill(4)}_roi_{xmin}_{xmax}_{ymin}_{ymax}.json"
732
+ )
733
+ with open(info_name, "w") as f:
734
+ json.dump(info, f, indent=4)
735
+
736
+ if fov_export:
737
+ frame = viewer.layers["Image"].data[t]
738
+ multichannel = [frame]
739
+ for i in range(len(channel_indices) - 1):
740
+ try:
741
+ frame = viewer.layers[f"Image [{i + 1}]"].data[t]
742
+ multichannel.append(frame)
743
+ except:
744
+ pass
745
+ multichannel = np.array(multichannel)
746
+ save_tiff_imagej_compatible(
747
+ annotation_folder
748
+ + f"{exp_name}_{position.split(os.sep)[-2]}_{str(t).zfill(4)}_labelled.tif",
749
+ labels_layer,
750
+ axes="YX",
751
+ )
752
+ save_tiff_imagej_compatible(
753
+ annotation_folder
754
+ + f"{exp_name}_{position.split(os.sep)[-2]}_{str(t).zfill(4)}.tif",
755
+ multichannel,
756
+ axes="CYX",
757
+ )
758
+
759
+ info.update(
760
+ {
761
+ "spatial_calibration": spatial_calibration,
762
+ "channels": list(channel_names),
763
+ "frame": t,
764
+ }
765
+ )
766
+
767
+ info_name = (
768
+ annotation_folder
769
+ + f"{exp_name}_{position.split(os.sep)[-2]}_{str(t).zfill(4)}.json"
770
+ )
771
+ with open(info_name, "w") as f:
772
+ json.dump(info, f, indent=4)
773
+
774
+ print("Done.")
775
+
776
+ @magicgui(call_button="Save the modified labels")
777
+ def save_widget():
778
+ return export_labels()
779
+
780
+ @magicgui(call_button="Export the annotation\nof the current frame")
781
+ def export_widget():
782
+ return export_annotation()
783
+
784
+ from celldetective.gui.base.styles import Styles
785
+
786
+ stack, labels = locate_stack_and_labels(
787
+ position, prefix=prefix, population=population
788
+ )
789
+ contrast_limits = _get_contrast_limits(stack)
790
+
791
+ output_folder = position + f"labels_{population}{os.sep}"
792
+ logger.info(f"Shape of the loaded image stack: {stack.shape}...")
793
+
794
+ viewer = napari.Viewer()
795
+ try:
796
+ viewer.window._qt_window.setWindowIcon(Styles().celldetective_icon)
797
+ except Exception as e:
798
+ pass
799
+ viewer.add_image(
800
+ stack,
801
+ channel_axis=-1,
802
+ colormap=["gray"] * stack.shape[-1],
803
+ contrast_limits=contrast_limits,
804
+ )
805
+ viewer.add_labels(labels.astype(int), name="segmentation", opacity=0.4)
806
+
807
+ button_container = QWidget()
808
+ layout = QVBoxLayout(button_container)
809
+ layout.setSpacing(10)
810
+ layout.addWidget(save_widget.native)
811
+ layout.addWidget(export_widget.native)
812
+ viewer.window.add_dock_widget(button_container, area="right")
813
+
814
+ save_widget.native.setStyleSheet(Styles().button_style_sheet)
815
+ export_widget.native.setStyleSheet(Styles().button_style_sheet)
816
+
817
+ def lock_controls(layer, widgets=(), locked=True):
818
+ qctrl = viewer.window.qt_viewer.controls.widgets[layer]
819
+ for wdg in widgets:
820
+ try:
821
+ getattr(qctrl, wdg).setEnabled(not locked)
822
+ except:
823
+ pass
824
+
825
+ label_widget_list = ["polygon_button", "transform_button"]
826
+ lock_controls(viewer.layers["segmentation"], label_widget_list)
827
+
828
+ viewer.show(block=True)
829
+
830
+ if flush_memory:
831
+ # temporary fix for slight napari memory leak
832
+ for i in range(10000):
833
+ try:
834
+ viewer.layers.pop()
835
+ except:
836
+ pass
837
+
838
+ del viewer
839
+ del stack
840
+ del labels
841
+ gc.collect()
842
+
843
+ logger.info("napari viewer was successfully closed...")
844
+
845
+
846
def correct_annotation(filename):
    """
    Re-annotate an existing annotation image in napari and save the labels in place.

    Parameters
    ----------
    filename : str
        Path to the annotation image. A path ending in ``_labelled.tif`` or
        ``.json`` is mapped to the matching ``.tif`` image before loading.

    Raises
    ------
    AssertionError
        If the resolved image file does not exist.

    Notes
    -----
    Labels are read from ``<image stem>_labelled<ext>`` when that file exists,
    otherwise annotation starts from an empty label image. The corrected labels
    are written back to that same path (as int16, axes "YX") only when the
    "Save the modified labels" button is clicked. The viewer call blocks until
    the window is closed.
    """

    def export_labels():
        # Write every frame of the edited segmentation layer back to the label file.
        labels_layer = viewer.layers["segmentation"].data
        for im in tqdm(labels_layer):
            try:
                im = auto_correct_masks(im)
            except Exception as e:
                # Best effort: keep the uncorrected mask if auto-correction fails.
                print(e)

            save_tiff_imagej_compatible(existing_lbl, im.astype(np.int16), axes="YX")
        print("The labels have been successfully rewritten.")

    @magicgui(call_button="Save the modified labels")
    def save_widget():
        return export_labels()

    # Map companion files (label image / metadata) back to the raw image.
    # Suffix slicing (instead of str.replace) only touches the file suffix,
    # never matching substrings elsewhere in the path.
    if filename.endswith("_labelled.tif"):
        filename = filename[: -len("_labelled.tif")] + ".tif"
    if filename.endswith(".json"):
        filename = filename[: -len(".json")] + ".tif"
    assert os.path.exists(filename), f"Image {filename} does not seem to exist..."

    img = imread(filename.replace("\\", "/"))
    if img.ndim == 3:
        # Stored channel-first; napari is given channels on the last axis.
        img = np.moveaxis(img, 0, -1)
    elif img.ndim == 2:
        img = img[:, :, np.newaxis]

    # Derive the label path from the extension only, so a ".tif" inside a
    # directory name, or an upper-case extension, never makes the label path
    # collide with (and overwrite) the image itself.
    stem, ext = os.path.splitext(filename)
    existing_lbl = stem + "_labelled" + ext
    if os.path.exists(existing_lbl):
        labels = imread(existing_lbl)[np.newaxis, :, :].astype(int)
    else:
        labels = np.zeros_like(img[:, :, 0]).astype(int)[np.newaxis, :, :]

    # Add a singleton leading (frame) axis so the layout matches the label stack.
    stack = img[np.newaxis, :, :, :]
    contrast_limits = _get_contrast_limits(stack)
    viewer = napari.Viewer()
    viewer.add_image(
        stack,
        channel_axis=-1,
        colormap=["gray"] * stack.shape[-1],
        contrast_limits=contrast_limits,
    )
    viewer.add_labels(labels, name="segmentation", opacity=0.4)
    viewer.window.add_dock_widget(save_widget, area="right")
    viewer.show(block=True)

    # Temporary fix for a slight napari memory leak: drop every layer explicitly.
    while True:
        try:
            viewer.layers.pop()
        except Exception:
            break
    del viewer
    del stack
    del labels
    gc.collect()
908
+
909
+
910
def _view_on_napari(tracks=None, stack=None, labels=None):
    """
    Visualize tracks, an image stack and labels in a blocking napari viewer.

    Parameters
    ----------
    tracks : array-like, optional
        Track data forwarded unchanged to ``napari.Viewer.add_tracks``. napari
        expects one row per point laid out as ``(track_id, t, y, x)``.
        Default is None.
    stack : numpy array, optional
        Image stack with channels on the last axis, e.g. (T, Y, X, C). Each
        channel is shown as its own grayscale layer. Default is None.
    labels : numpy array, optional
        Label stack, e.g. (T, Y, X), shown as a "segmentation" overlay.
        Default is None.

    Returns
    -------
    None

    Notes
    -----
    Only the inputs that are not None are added to the viewer. The call blocks
    until the napari window is closed.

    Examples
    --------
    >>> tracks = np.array([[1, 0, 15, 10], [1, 1, 16, 11], [2, 0, 35, 30]])
    >>> stack = np.random.rand(2, 100, 100, 3)
    >>> labels = np.random.randint(0, 2, (2, 100, 100))
    >>> _view_on_napari(tracks, stack=stack, labels=labels)  # doctest: +SKIP
    """

    viewer = napari.Viewer()
    if stack is not None:
        # Only compute contrast limits when there is actually a stack to display.
        contrast_limits = _get_contrast_limits(stack)
        viewer.add_image(
            stack,
            channel_axis=-1,
            colormap=["gray"] * stack.shape[-1],
            contrast_limits=contrast_limits,
        )
    if labels is not None:
        viewer.add_labels(labels, name="segmentation", opacity=0.4)
    if tracks is not None:
        viewer.add_tracks(tracks, name="tracks")
    viewer.show(block=True)
960
+
961
+
962
def control_tracking_table(
    position,
    calibration=1,
    prefix="Aligned",
    population="target",
    column_labels=None,
):
    """
    Load the tracking table of a position and visualize the tracks in napari.

    Parameters
    ----------
    position : str
        The position directory holding the tracking data.
    calibration : float, optional
        Spatial calibration factor. The x and y columns are divided by it
        before display. Default is 1.
    prefix : str, optional
        Prefix of the image stack to load. Default is "Aligned".
    population : str, optional
        Population type, e.g. "target" or "effector". Default is "target".
    column_labels : dict, optional
        Mapping from the keys ``'track'``, ``'frame'``, ``'y'`` and ``'x'`` to the
        tracking-table column names. Default (when None) is
        {'track': "TRACK_ID", 'frame': 'FRAME', 'y': 'POSITION_Y', 'x': 'POSITION_X', 'label': 'class_id'}.

    Returns
    -------
    None

    Notes
    -----
    The tracking table, labels and stack are loaded with ``load_tracking_data``.
    The track, frame, y and x columns are extracted in that order, so each row
    matches napari's ``(track_id, t, y, x)`` track layout. The y/x coordinates
    are divided by ``calibration``, presumably converting physical positions
    back to pixel coordinates so that they overlay the image stack (TODO
    confirm the unit convention). The result is displayed with
    ``_view_on_napari``.

    Examples
    --------
    >>> control_tracking_table('path/to/position', calibration=0.1, prefix='Aligned', population='target')  # doctest: +SKIP
    """

    # Build the default per call instead of using a shared mutable default argument.
    if column_labels is None:
        column_labels = {
            "track": "TRACK_ID",
            "frame": "FRAME",
            "y": "POSITION_Y",
            "x": "POSITION_X",
            "label": "class_id",
        }

    position = position.replace("\\", "/")
    tracks, labels, stack = load_tracking_data(
        position, prefix=prefix, population=population
    )
    # Force a float array: in-place division fails on an all-integer selection.
    tracks = tracks.loc[
        :,
        [
            column_labels["track"],
            column_labels["frame"],
            column_labels["y"],
            column_labels["x"],
        ],
    ].to_numpy(dtype=float)
    tracks[:, -2:] /= calibration
    _view_on_napari(tracks, labels=labels, stack=stack)