celldetective 1.5.0b0__py3-none-any.whl → 1.5.0b2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (29)
  1. celldetective/_version.py +1 -1
  2. celldetective/gui/InitWindow.py +26 -4
  3. celldetective/gui/base/components.py +20 -1
  4. celldetective/gui/base_annotator.py +11 -7
  5. celldetective/gui/event_annotator.py +51 -1060
  6. celldetective/gui/interactions_block.py +55 -25
  7. celldetective/gui/interactive_timeseries_viewer.py +11 -1
  8. celldetective/gui/measure_annotator.py +968 -0
  9. celldetective/gui/process_block.py +88 -34
  10. celldetective/gui/viewers/base_viewer.py +134 -3
  11. celldetective/gui/viewers/contour_viewer.py +4 -4
  12. celldetective/gui/workers.py +124 -10
  13. celldetective/measure.py +3 -0
  14. celldetective/napari/utils.py +29 -19
  15. celldetective/processes/downloader.py +122 -96
  16. celldetective/processes/load_table.py +55 -0
  17. celldetective/processes/measure_cells.py +107 -81
  18. celldetective/processes/track_cells.py +39 -39
  19. celldetective/segmentation.py +1 -1
  20. celldetective/tracking.py +9 -0
  21. celldetective/utils/data_loaders.py +21 -1
  22. celldetective/utils/downloaders.py +38 -0
  23. celldetective/utils/maths.py +14 -1
  24. {celldetective-1.5.0b0.dist-info → celldetective-1.5.0b2.dist-info}/METADATA +1 -1
  25. {celldetective-1.5.0b0.dist-info → celldetective-1.5.0b2.dist-info}/RECORD +29 -27
  26. {celldetective-1.5.0b0.dist-info → celldetective-1.5.0b2.dist-info}/WHEEL +0 -0
  27. {celldetective-1.5.0b0.dist-info → celldetective-1.5.0b2.dist-info}/entry_points.txt +0 -0
  28. {celldetective-1.5.0b0.dist-info → celldetective-1.5.0b2.dist-info}/licenses/LICENSE +0 -0
  29. {celldetective-1.5.0b0.dist-info → celldetective-1.5.0b2.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,968 @@
1
+ from PyQt5.QtWidgets import (
2
+ QApplication,
3
+ QHBoxLayout,
4
+ QVBoxLayout,
5
+ QLabel,
6
+ QLineEdit,
7
+ QPushButton,
8
+ QMessageBox,
9
+ QSlider,
10
+ QComboBox,
11
+ )
12
+ from PyQt5.QtCore import Qt, QSize
13
+ from PyQt5.QtGui import QIntValidator, QKeySequence
14
+ import numpy as np
15
+ import pandas as pd
16
+ import os
17
+ import matplotlib.pyplot as plt
18
+ from matplotlib.cm import tab10
19
+ from superqt import QLabeledDoubleSlider
20
+ from superqt.fonticon import icon
21
+ from fonticon_mdi6 import MDI6
22
+
23
+ from celldetective.gui.base_annotator import BaseAnnotator
24
+ from celldetective.gui.viewers.contour_viewer import CellEdgeVisualizer
25
+ from celldetective.gui.base.components import CelldetectiveWidget
26
+ from celldetective.gui.gui_utils import color_from_state, color_from_class
27
+ from celldetective.utils.image_loaders import locate_labels, load_frames
28
+ from celldetective.utils.masks import contour_of_instance_segmentation
29
+ from celldetective.gui.base.figure_canvas import FigureCanvas
30
+ from celldetective.gui.base.utils import center_window
31
+ from celldetective import get_logger
32
+
33
+ logger = get_logger(__name__)
34
+
35
+
36
class AnnotatorStackVisualizer(CellEdgeVisualizer):
    """Stack viewer used by the static (measurement) annotator.

    Extends ``CellEdgeVisualizer`` with a scatter overlay that marks cell
    positions, hides the edge slider, shows a compact opacity slider and
    removes margins/spacing from the canvas layout.
    """

    def __init__(self, *args, **kwargs):
        # Defined before super().__init__ in case the parent constructor
        # triggers generate_figure_canvas() (overridden below) -- assumption,
        # TODO confirm against CellEdgeVisualizer.
        self.scat_markers = None
        super().__init__(*args, **kwargs)
        self.compact_layout()

    def generate_figure_canvas(self):
        """Build the parent canvas, then add the scatter overlay and compact it."""
        super().generate_figure_canvas()
        self.generate_custom_overlays()
        # Force layout update
        self.compact_layout()

    def generate_custom_overlays(self):
        """Initialize scatter artists."""
        self.scat_markers = self.ax.scatter([], [], color="tab:red", picker=True)
        # CellEdgeVisualizer handles self.im_mask

    def update_overlays(self, positions, colors):
        """Refresh the scatter overlay with the cells of the displayed frame.

        Parameters
        ----------
        positions : array-like
            (x, y) coordinates of the cells to mark; may be empty.
        colors : array-like
            One edge color per position. Only applied when both inputs are
            non-empty.
        """
        # Normalize to an (N, 2) float array: an empty list would otherwise
        # reach set_offsets as a 1-D array, which recent Matplotlib versions
        # reject when indexing its x/y columns.
        offsets = np.asarray(positions, dtype=float).reshape(-1, 2)
        self.scat_markers.set_offsets(offsets)
        if len(offsets) > 0 and len(colors) > 0:
            # colors should be an array of colors
            self.scat_markers.set_edgecolors(colors)
            self.scat_markers.set_facecolors("none")
            self.scat_markers.set_picker(10)
            self.scat_markers.set_linewidths(2)
            self.scat_markers.set_sizes([200] * len(offsets))

        self.canvas.canvas.draw_idle()

    def generate_edge_slider(self):
        """Intentionally empty: the annotator does not expose the edge slider."""
        pass

    def generate_opacity_slider(self):
        """Add a compact "Opacity:" slider (0-1, two decimals) for the mask overlay."""
        self.opacity_slider = QLabeledDoubleSlider()
        self.opacity_slider.setOrientation(Qt.Horizontal)
        self.opacity_slider.setRange(0, 1)
        self.opacity_slider.setValue(self.mask_alpha)
        self.opacity_slider.setDecimals(2)
        self.opacity_slider.valueChanged.connect(self.change_mask_opacity)

        layout = QHBoxLayout()
        layout.setContentsMargins(0, 0, 0, 0)
        layout.setSpacing(0)
        layout.addWidget(QLabel("Opacity:"), 15)
        layout.addWidget(self.opacity_slider, 85)
        self.canvas.layout.addLayout(layout)

    def compact_layout(self):
        """Remove margins and spacing from the canvas layout and its nested layouts."""
        self.canvas.layout.setSpacing(0)
        self.canvas.layout.setContentsMargins(0, 5, 0, 5)
        for i in range(self.canvas.layout.count()):
            sub_layout = self.canvas.layout.itemAt(i).layout()
            # Direct widgets are left as they are; only nested layouts are tightened.
            if sub_layout:
                sub_layout.setContentsMargins(0, 0, 0, 0)
                sub_layout.setSpacing(0)
98
+
99
+
100
class MeasureAnnotator(BaseAnnotator):
    """Static annotator: assign per-frame values to cells in a trajectory table.

    The user browses the position's image stack, picks a cell on the scatter
    overlay and writes an integer value into the active characteristic-group
    column (``self.status_name``) for that cell at the current frame. Box
    plots of the selected signals are drawn in the left panel, with the
    selected cell's values highlighted.
    """

    def __init__(self, *args, **kwargs):

        super().__init__(*args, read_config=False, **kwargs)

        self.setWindowTitle("Static annotator")

        self.int_validator = QIntValidator()
        self.current_alpha = 0.5
        self.value_magnitude = 1

        epsilon = 0.01
        self.observed_min_intensity = 0
        self.observed_max_intensity = 0 + epsilon

        self.current_frame = 0
        self.show_fliers = False
        self.status_name = "group"

        # NOTE(review): proceed / len_movie / pos / mode are expected to be set
        # by BaseAnnotator.__init__ -- confirm.
        if self.proceed:

            from celldetective.utils.image_loaders import fix_missing_labels

            # Ensure there is one label image per stack frame.
            if self.len_movie > 0:
                temp_labels = locate_labels(self.pos, population=self.mode)
                if temp_labels is None or len(temp_labels) < self.len_movie:
                    fix_missing_labels(self.pos, population=self.mode)
                    self.labels = locate_labels(self.pos, population=self.mode)
                elif len(temp_labels) > self.len_movie:
                    self.labels = temp_labels[: self.len_movie]
                else:
                    self.labels = temp_labels
            else:
                self.labels = locate_labels(self.pos, population=self.mode)

            self.current_channel = 0
            self.frame_lbl = QLabel("position: ")

            # The image is displayed by the StackVisualizer created in populate_window.
            self.populate_window()
            self.changed_class()

            self.previous_index = None

        else:
            self.close()

    # ------------------------------------------------------------------
    # Table helpers
    # ------------------------------------------------------------------

    def _tracking_id_column(self):
        """Return the cell identifier column: "TRACK_ID" if present, else "ID"."""
        return "TRACK_ID" if "TRACK_ID" in self.df_tracks.columns else "ID"

    def _current_cell_mask(self):
        """Boolean row mask for ``self.track_of_interest`` at ``self.current_frame``."""
        id_col = self._tracking_id_column()
        return (self.df_tracks[id_col] == self.track_of_interest) & (
            self.df_tracks["FRAME"] == self.current_frame
        )

    def _locate_track_of_interest(self):
        """Fill ``self.loc_t`` / ``self.loc_idx`` with every (frame, row-in-frame)
        where ``self.track_of_interest`` appears in ``self.tracks``."""
        self.loc_t = []
        self.loc_idx = []
        for t, frame_tracks in enumerate(self.tracks):
            indices = np.where(frame_tracks == self.track_of_interest)[0]
            if len(indices) > 0:
                self.loc_t.append(t)
                self.loc_idx.append(indices[0])

    def locate_tracks(self):
        """
        Load the trajectory table and prepare it for annotation.

        Reads ``self.trajectories_path``, sorts rows by cell identifier and
        frame, picks the active characteristic column, assigns a display color
        to each row, caches per-frame scatter data and fits
        ``self.MinMaxScaler`` on the remaining measurement columns. If the
        table does not exist, a warning is shown and the window is closed.
        """

        if not os.path.exists(self.trajectories_path):

            msgBox = QMessageBox()
            msgBox.setIcon(QMessageBox.Warning)
            msgBox.setText("The trajectories cannot be detected.")
            msgBox.setWindowTitle("Warning")
            msgBox.setStandardButtons(QMessageBox.Ok)
            returnValue = msgBox.exec()
            # The dialog only offers "Ok"; comparing with "Yes" could never
            # succeed, so the window was never closed.
            if returnValue == QMessageBox.Ok:
                self.close()
            return

        # Load and prep tracks
        self.df_tracks = pd.read_csv(self.trajectories_path)
        id_col = self._tracking_id_column()
        self.df_tracks = self.df_tracks.sort_values(by=[id_col, "FRAME"])

        # NOTE(review): "group*"/"class*" columns are selected here, whereas
        # update_class_cb selects "group*"/"status*" columns -- confirm intent.
        self.class_cols = [
            c for c in self.df_tracks.columns if c.startswith(("group", "class"))
        ]
        for col in ["class_id", "group_color", "class_color"]:
            if col in self.class_cols:
                self.class_cols.remove(col)

        self.status_name = self.class_cols[0] if self.class_cols else "group"

        if self.status_name not in self.df_tracks.columns:
            # only create the status column if it does not exist to not erase
            # static classification results
            self.make_status_column()

        all_states = np.array(self.df_tracks.loc[:, self.status_name].tolist())
        self.state_color_map = color_from_state(all_states, recently_modified=False)
        self.df_tracks["group_color"] = self.df_tracks[self.status_name].apply(
            self.assign_color_state
        )

        self.df_tracks = self.df_tracks.dropna(subset=["POSITION_X", "POSITION_Y"])
        self.df_tracks["x_anim"] = self.df_tracks["POSITION_X"].astype(int)
        self.df_tracks["y_anim"] = self.df_tracks["POSITION_Y"].astype(int)

        self.extract_scatter_from_trajectories()
        self.track_of_interest = self.df_tracks.dropna(subset=id_col)[id_col].min()
        self._locate_track_of_interest()

        from sklearn.preprocessing import MinMaxScaler

        self.MinMaxScaler = MinMaxScaler()

        # Identifier, bookkeeping, annotation and metadata columns are not rescaled.
        cols_to_remove = [
            "group",
            "group_color",
            "status",
            "status_color",
            "class_color",
            "TRACK_ID",
            "FRAME",
            "x_anim",
            "y_anim",
            "t",
            "dummy",
            "state",
            "generation",
            "root",
            "parent",
            "class_id",
            "class",
            "t0",
            "POSITION_X",
            "POSITION_Y",
            "position",
            "well",
            "well_index",
            "well_name",
            "pos_name",
            "index",
            "concentration",
            "cell_type",
            "antibody",
            "pharmaceutical_agent",
            "ID",
        ] + self.class_cols

        from celldetective.utils.experiment import (
            get_experiment_metadata,
            get_experiment_labels,
        )

        meta = get_experiment_metadata(self.exp_dir)
        if meta is not None:
            cols_to_remove.extend(meta.keys())

        labels = get_experiment_labels(self.exp_dir)
        if labels is not None:
            cols_to_remove.extend(labels.keys())

        excluded = set(cols_to_remove)
        self.columns_to_rescale = [
            c for c in self.df_tracks.columns if c not in excluded
        ]

        x = self.df_tracks[self.columns_to_rescale].values
        self.MinMaxScaler.fit(x)

    # ------------------------------------------------------------------
    # Widgets
    # ------------------------------------------------------------------

    def populate_options_layout(self):
        """Replace the options row with the value field and the delete-cell button."""
        # clear options hbox
        for i in reversed(range(self.options_hbox.count())):
            widget = self.options_hbox.itemAt(i).widget()
            # Nested layouts have no widget; skip them instead of crashing.
            if widget is not None:
                widget.setParent(None)

        time_option_hbox = QHBoxLayout()
        time_option_hbox.setContentsMargins(100, 0, 100, 0)
        time_option_hbox.setSpacing(0)

        self.time_of_interest_label = QLabel("phenotype: ")
        time_option_hbox.addWidget(self.time_of_interest_label, 30)

        self.time_of_interest_le = QLineEdit()
        self.time_of_interest_le.setValidator(self.int_validator)
        time_option_hbox.addWidget(self.time_of_interest_le)

        self.suppr_btn = QPushButton("")
        self.suppr_btn.setStyleSheet(self.button_select_all)
        self.suppr_btn.setIcon(icon(MDI6.delete, color="black"))
        self.suppr_btn.setToolTip("Delete cell")
        self.suppr_btn.setIconSize(QSize(20, 20))
        self.suppr_btn.clicked.connect(self.del_cell)
        time_option_hbox.addWidget(self.suppr_btn)

        self.options_hbox.addLayout(time_option_hbox)

    def update_widgets(self):
        """Adapt the base-class widgets' labels, tooltips and export action."""

        self.class_label.setText("characteristic \n group: ")
        self.update_class_cb()
        self.add_class_btn.setToolTip("Add a new characteristic group")
        self.del_class_btn.setToolTip("Delete a characteristic group")

        self.export_btn.disconnect()
        self.export_btn.clicked.connect(self.export_measurements)

    def update_class_cb(self):
        """Refill the group combobox with the table's "group*"/"status*" columns."""

        self.class_choice_cb.disconnect()
        self.class_choice_cb.clear()

        excluded = {
            "group_id",
            "group_color",
            "class_id",
            "class_color",
            "status_color",
        }
        self.class_cols = [
            c
            for c in self.df_tracks.columns
            if c.startswith(("group", "status")) and c not in excluded
        ]

        self.class_choice_cb.addItems(self.class_cols)
        self.class_choice_cb.currentIndexChanged.connect(self.changed_class)

    def populate_window(self):
        """Build the left-panel controls and embed the stack viewer on the right."""

        super().populate_window()
        # Left panel updates
        self.populate_options_layout()
        self.update_widgets()

        self.annotation_btns_to_hide = [
            self.time_of_interest_label,
            self.time_of_interest_le,
            self.suppr_btn,
        ]
        self.hide_annotation_buttons()

        # Right panel - Initialize StackVisualizer
        self.viewer = AnnotatorStackVisualizer(
            stack_path=self.stack_path,
            labels=self.labels,
            frame_slider=True,
            channel_cb=True,
            channel_names=self.channel_names,
            n_channels=self.nbr_channels,
            target_channel=0,
            window_title="Stack Viewer",
        )

        # Connect viewer signals
        self.viewer.frame_slider.valueChanged.connect(self.sync_frame)
        self.viewer.channel_cb.currentIndexChanged.connect(self.plot_signals)

        # Connect mpl event
        self.cid_pick = self.viewer.fig.canvas.mpl_connect(
            "pick_event", self.on_scatter_pick
        )

        self.right_panel.addWidget(self.viewer.canvas)

        # Force start at frame 0
        self.viewer.frame_slider.setValue(0)

        self.plot_signals()
        self.compact_layout_main()

    def compact_layout_main(self):
        """Compact the viewer layout again once it is embedded in the window."""
        if hasattr(self, "viewer"):
            self.viewer.compact_layout()

    def sync_frame(self, value):
        """Callback when StackVisualizer frame changes"""

        self.current_frame = value
        self.update_frame_logic()

    # ------------------------------------------------------------------
    # Signal plots
    # ------------------------------------------------------------------

    def plot_signals(self):
        """Draw box plots of the selected signals in ``self.cell_ax``.

        One box per selected signal (all rows), a jittered strip of the
        current frame's values, and the selected cell's values at the current
        frame as red hexagons. Does nothing until the viewer exists.
        """
        if not hasattr(self, "viewer"):
            return

        current_frame = self.current_frame
        cell_mask = self._current_cell_mask()
        frame_mask = self.df_tracks["FRAME"] == current_frame

        yvalues = []  # selected cell's values (NaN kept)
        xvalues = []  # box position of each entry of yvalues
        all_yvalues = []
        current_yvalues = []
        range_values = []

        for i in range(len(self.signal_choice_cb)):
            signal_choice = self.signal_choice_cb[i].currentText()
            if signal_choice == "--":
                continue

            position = len(all_yvalues) + 1
            ydata = self.df_tracks.loc[cell_mask, signal_choice].to_numpy()

            all_ydata = self.df_tracks.loc[:, signal_choice].to_numpy()
            all_ydata = all_ydata[all_ydata == all_ydata]  # remove nan

            current_ydata = self.df_tracks.loc[frame_mask, signal_choice].to_numpy()
            current_ydata = current_ydata[current_ydata == current_ydata]

            yvalues.extend(ydata)
            xvalues.extend([position] * len(ydata))
            current_yvalues.append(current_ydata)
            all_yvalues.append(all_ydata)
            range_values.extend(all_ydata)

        self.cell_ax.clear()
        if self.log_scale:
            self.cell_ax.set_yscale("log")
        else:
            self.cell_ax.set_yscale("linear")

        if len(yvalues) > 0:
            try:
                self.cell_ax.boxplot(all_yvalues, showfliers=self.show_fliers)
            except Exception as e:
                logger.error(f"{e=}")

            for index, feature in enumerate(current_yvalues):
                x_values_strip = (index + 1) + np.random.normal(
                    0, 0.04, size=len(feature)
                )
                self.cell_ax.plot(
                    x_values_strip,
                    feature,
                    marker="o",
                    linestyle="None",
                    color=tab10.colors[0],
                    alpha=0.1,
                )
            # Must stay the last line drawn: select_single_cell removes
            # cell_ax.lines[-1] to refresh the highlighted points.
            self.cell_ax.plot(
                xvalues,
                yvalues,
                marker="H",
                linestyle="None",
                color=tab10.colors[3],
                alpha=1,
            )

            range_values = np.array(range_values)
            range_values = range_values[range_values == range_values]

            # Filter out non-positive values if log scale is active to prevent warnings
            if self.log_scale:
                range_values = range_values[range_values > 0]

            if len(range_values) > 0:
                vmin = np.nanmin(range_values)
                vmax = np.nanmax(range_values)
                margin = 0.03 * (vmax - vmin)
                self.value_magnitude = vmin - margin
                self.non_log_ymin = vmin - margin
                self.non_log_ymax = vmax + margin
                if self.cell_ax.get_yscale() == "linear":
                    self.cell_ax.set_ylim(self.non_log_ymin, self.non_log_ymax)
                else:
                    self.cell_ax.set_ylim(self.value_magnitude, self.non_log_ymax)
            else:
                # No finite (or, in log scale, positive) values: np.nanmin would
                # raise on the empty array, so keep matplotlib's default limits.
                self.value_magnitude = 1
        else:
            self.cell_ax.text(
                0.5,
                0.5,
                "No data available",
                horizontalalignment="center",
                verticalalignment="center",
                transform=self.cell_ax.transAxes,
            )

        self.cell_fcanvas.canvas.draw()

    def plot_red_points(self, ax):
        """Plot the selected cell's non-NaN values at the current frame on ``ax``.

        Each value is drawn at the x position of its signal's box, so a NaN in
        one signal does not shift the others.
        """
        cell_mask = self._current_cell_mask()
        xvalues = []
        yvalues = []
        position = 0
        for i in range(len(self.signal_choice_cb)):
            signal_choice = self.signal_choice_cb[i].currentText()
            if signal_choice == "--":
                continue
            position += 1
            ydata = self.df_tracks.loc[cell_mask, signal_choice].to_numpy()
            ydata = ydata[ydata == ydata]  # remove nan
            xvalues.extend([position] * len(ydata))
            yvalues.extend(ydata)
        ax.plot(
            xvalues, yvalues, marker="H", linestyle="None", color=tab10.colors[3], alpha=1
        )  # Plot red points representing cells
        self.cell_fcanvas.canvas.draw()

    # ------------------------------------------------------------------
    # Selection and annotation
    # ------------------------------------------------------------------

    def select_single_cell(self, index, timepoint):
        """Select the cell at row ``index`` of frame ``timepoint`` and highlight it."""

        self.correct_btn.setEnabled(True)
        self.cancel_btn.setEnabled(True)
        self.del_shortcut.setEnabled(True)

        self.track_of_interest = self.tracks[timepoint][index]
        logger.info(f"You selected cell #{self.track_of_interest}...")
        self.give_cell_information()

        if len(self.cell_ax.lines) > 0:
            # The last line holds the previously highlighted points.
            self.cell_ax.lines[-1].remove()
            self.plot_red_points(self.cell_ax)
        else:
            self.plot_signals()

        self._locate_track_of_interest()

        # Remember the original colors, then paint the selected cell lime.
        self.previous_color = []
        for t, idx in zip(self.loc_t, self.loc_idx):
            self.previous_color.append(self.colors[t][idx].copy())
            self.colors[t][idx] = "lime"

        self.draw_frame(self.current_frame)

    def cancel_selection(self):
        """Clear the selection (base class), forget the pick event and redraw."""
        super().cancel_selection()
        self.event = None
        self.draw_frame(self.current_frame)

    def export_measurements(self):
        """Save the current frame's rows as a list of records in a ``.npy`` file.

        The active group column is copied to a "class" field in each record.
        Signals are de-normalized first if normalization is active.
        """
        # Local import: QFileDialog is not among the module-level PyQt5 imports.
        from PyQt5.QtWidgets import QFileDialog

        logger.info("User interactions: Exporting measurements...")
        auto_dataset_name = (
            self.pos.split(os.sep)[-4]
            + "_"
            + self.pos.split(os.sep)[-2]
            + f"_{str(self.current_frame).zfill(3)}"
            + f"_{self.status_name}.npy"
        )

        if self.normalized_signals:
            self.normalize_features_btn.click()

        # Copy so that adding "class" does not write into a slice of df_tracks.
        subdf = self.df_tracks.loc[self.df_tracks["FRAME"] == self.current_frame, :].copy()
        subdf["class"] = subdf[self.status_name]
        dico = subdf.to_dict("records")

        pathsave = QFileDialog.getSaveFileName(
            self,
            "Select file name",
            os.path.join(self.exp_dir, auto_dataset_name),
            ".npy",
        )[0]
        if pathsave != "":
            if not pathsave.endswith(".npy"):
                pathsave += ".npy"
            try:
                np.save(pathsave, dico)
                logger.info(f"File successfully written in {pathsave}.")
            except Exception as e:
                logger.error(f"Error {e}...")

    def write_new_event_class(self):
        """Create (or reset to 0) the group column named in the creation dialog."""

        if self.class_name_le.text() == "":
            self.target_class = "group"
        else:
            self.target_class = "group_" + self.class_name_le.text()

        logger.info(
            f"User interactions: Creating new characteristic group '{self.target_class}'"
        )

        if self.target_class in list(self.df_tracks.columns):
            msgBox = QMessageBox()
            msgBox.setIcon(QMessageBox.Warning)
            msgBox.setText(
                "This characteristic group name already exists. If you proceed,\nall annotated data will be rewritten. Do you wish to continue?"
            )
            msgBox.setWindowTitle("Warning")
            msgBox.setStandardButtons(QMessageBox.Yes | QMessageBox.No)
            returnValue = msgBox.exec()
            if returnValue == QMessageBox.No:
                return None

        self.df_tracks.loc[:, self.target_class] = 0

        self.update_class_cb()

        idx = self.class_choice_cb.findText(self.target_class)
        self.status_name = self.target_class
        self.class_choice_cb.setCurrentIndex(idx)
        self.newClassWidget.close()

    def hide_annotation_buttons(self):
        """Hide and disable the value-entry widgets and clear the value field."""

        for a in self.annotation_btns_to_hide:
            a.hide()
        self.time_of_interest_label.setEnabled(False)
        self.time_of_interest_le.setText("")
        self.time_of_interest_le.setEnabled(False)

    def show_annotation_buttons(self):
        """Show the value-entry widgets and turn "correct" into "submit"."""

        for a in self.annotation_btns_to_hide:
            a.show()

        self.time_of_interest_label.setEnabled(True)
        self.time_of_interest_le.setEnabled(True)
        self.correct_btn.setText("submit")

        self.correct_btn.disconnect()
        self.correct_btn.clicked.connect(self.apply_modification)

    def give_cell_information(self):
        """Show the selected cell's id and its group value at the current frame.

        Both TRACK_ID and ID tables are filtered by frame, matching what
        apply_modification writes. Lookup failures (e.g. the cell is absent
        from the current frame) are logged.
        """

        try:
            cell_selected = f"cell: {self.track_of_interest}\n"
            status = self.df_tracks.loc[
                self._current_cell_mask(), self.status_name
            ].to_numpy()[0]
            cell_status = f"phenotype: {status}\n"
            self.cell_info.setText(cell_selected + cell_status)
        except Exception as e:
            logger.error(e)

    def create_new_event_class(self):
        """Open a small dialog asking for the new characteristic group's name."""

        # display qwidget to name the event
        self.newClassWidget = CelldetectiveWidget()
        self.newClassWidget.setWindowTitle("Create new characteristic group")

        layout = QVBoxLayout()
        self.newClassWidget.setLayout(layout)
        name_hbox = QHBoxLayout()
        name_hbox.addWidget(QLabel("group name: "), 25)
        self.class_name_le = QLineEdit("group")
        name_hbox.addWidget(self.class_name_le, 75)
        layout.addLayout(name_hbox)

        btn_hbox = QHBoxLayout()
        submit_btn = QPushButton("submit")
        cancel_btn = QPushButton("cancel")
        btn_hbox.addWidget(cancel_btn, 50)
        btn_hbox.addWidget(submit_btn, 50)
        layout.addLayout(btn_hbox)

        submit_btn.clicked.connect(self.write_new_event_class)
        cancel_btn.clicked.connect(self.close_without_new_class)

        self.newClassWidget.show()
        center_window(self.newClassWidget)

    def apply_modification(self):
        """Write the entered value (empty -> 0) for the selected cell at the
        current frame, refresh colors/overlays and reset the annotation UI."""
        text = self.time_of_interest_le.text()
        if text != "":
            try:
                status = int(text)
            except ValueError:
                # QIntValidator accepts intermediate input such as "-", which
                # int() rejects; keep the UI open so the user can fix it.
                logger.warning(f"Invalid value '{text}': the annotation was not applied.")
                return
        else:
            status = 0

        logger.info(
            f"User interactions: Reclassifying cell #{self.track_of_interest} at frame {self.current_frame} to status {status}"
        )
        self.df_tracks.loc[self._current_cell_mask(), self.status_name] = status

        # Recompute colors and scatter data, refresh the info label and
        # rewire "correct" to show_annotation_buttons.
        self.modify()

        self.hide_annotation_buttons()
        self.correct_btn.setEnabled(False)
        self.correct_btn.setText("correct")
        self.cancel_btn.setEnabled(False)
        self.del_shortcut.setEnabled(False)

        if len(self.selection) > 0:
            self.selection.pop(0)

        self.draw_frame(self.current_frame)

    def assign_color_state(self, state):
        """Map a group value to its display color; missing values use the "nan" key."""

        # pd.isna also accepts strings (np.isnan raises TypeError on them).
        if pd.isna(state):
            state = "nan"
        return self.state_color_map[state]

    def on_scatter_pick(self, event):
        """Handle pick event on scatter plot.

        When several markers are hit, the one closest to the mouse is kept.
        Picking the already-selected cell on the same frame deselects it;
        any other pick replaces the selection.
        """
        self.event = event
        ind = event.ind

        if len(ind) > 1:
            # disambiguate based on distance to mouse click
            frame_positions = self.positions[self.current_frame]
            datax = np.array([frame_positions[i, 0] for i in ind])
            datay = np.array([frame_positions[i, 1] for i in ind])
            msx, msy = event.mouseevent.xdata, event.mouseevent.ydata
            dist = np.sqrt((datax - msx) ** 2 + (datay - msy) ** 2)
            ind = [ind[np.argmin(dist)]]

        if len(ind) > 0:
            idx = ind[0]

            # Enforce single selection / Toggle
            if len(self.selection) > 0:
                prev_idx, prev_frame = self.selection[0]
                # Row indices are per frame, so the frame must match too.
                if prev_idx == idx and prev_frame == self.current_frame:
                    self.cancel_selection()
                    return

                self.cancel_selection()

            self.selection = [[idx, self.current_frame]]
            self.select_single_cell(idx, self.current_frame)

    def draw_frame(self, framedata):
        """
        Update the viewer's scatter overlay for frame ``framedata``.
        Masks are refreshed by CellEdgeVisualizer itself on frame change.
        """
        self.framedata = framedata

        if self.framedata < len(self.positions):
            pos = self.positions[self.framedata]
            cols = self.colors[self.framedata][:, 0]
        else:
            # (0, 2) rather than [] so set_offsets receives a 2-D array.
            pos = np.empty((0, 2))
            cols = []

        self.viewer.update_overlays(
            positions=pos,
            colors=cols,
        )

    def make_status_column(self):
        """Create the active group column filled with 0 and recolor rows.

        Skipped for "state_firstdetection".
        """
        if self.status_name == "state_firstdetection":
            pass
        else:
            self.df_tracks.loc[:, self.status_name] = 0
            all_states = np.array(self.df_tracks.loc[:, self.status_name].tolist())
            self.state_color_map = color_from_state(all_states, recently_modified=False)
            self.df_tracks["group_color"] = self.df_tracks[self.status_name].apply(
                self.assign_color_state
            )

    def extract_scatter_from_trajectories(self):
        """Cache, per frame 0..len_movie-1, positions, colors and cell ids.

        ``self.colors[t]`` keeps the (n, 1) shape of the selected column.
        """
        id_col = self._tracking_id_column()
        self.positions = []
        self.colors = []
        self.tracks = []

        for t in np.arange(self.len_movie):
            in_frame = self.df_tracks["FRAME"] == t
            self.positions.append(
                self.df_tracks.loc[in_frame, ["POSITION_X", "POSITION_Y"]].to_numpy()
            )
            self.colors.append(self.df_tracks.loc[in_frame, ["group_color"]].to_numpy())
            self.tracks.append(self.df_tracks.loc[in_frame, id_col].to_numpy())

    def compute_status_and_colors(self, index=0):
        """Compatibility hook: ``index`` is ignored; delegates to changed_class."""
        self.changed_class()

    def changed_class(self):
        """Switch the active group column to the combobox choice and redraw."""
        self.status_name = self.class_choice_cb.currentText()
        if self.status_name != "":
            self.modify()
            self.draw_frame(self.current_frame)

    def update_frame_logic(self):
        """
        Logic to execute when frame changes.

        In ID mode (no TRACK_ID column) the cell of interest becomes the
        smallest ID present in the new frame.
        """
        if "TRACK_ID" in list(self.df_tracks.columns):
            pass
        elif "ID" in list(self.df_tracks.columns):
            candidates = self.df_tracks[self.df_tracks["FRAME"] == self.current_frame][
                "ID"
            ]
            if not candidates.empty:
                self.track_of_interest = candidates.min()
                self.modify()

        self.draw_frame(self.current_frame)
        self.give_cell_information()
        self.plot_signals()

    def changed_channel(self):
        """Handled by StackViewer mostly, but we might need to refresh plotting if things depend on channel"""
        pass  # StackViewer handles image reload

    def save_trajectories(self):
        """Write the annotated table back to ``self.trajectories_path`` and reload it."""
        logger.info("Saving trajectories...")
        if self.normalized_signals:
            self.normalize_features_btn.click()
        if self.selection:
            self.cancel_selection()

        # Drop display-only helper columns (and a column named ""); absent ones are ignored.
        self.df_tracks.drop(
            columns=["", "group_color", "x_anim", "y_anim"],
            errors="ignore",
            inplace=True,
        )

        self.df_tracks.to_csv(self.trajectories_path, index=False)
        logger.info("Table successfully exported...")

        self.locate_tracks()
        self.changed_class()

    def modify(self):
        """Recolor rows from the active group column and refresh cached overlays.

        Also refreshes the cell info label and rewires the "correct" button
        to show_annotation_buttons.
        """
        all_states = np.array(self.df_tracks.loc[:, self.status_name].tolist())
        self.state_color_map = color_from_state(all_states, recently_modified=False)

        self.df_tracks["group_color"] = self.df_tracks[self.status_name].apply(
            self.assign_color_state
        )

        self.extract_scatter_from_trajectories()
        self.give_cell_information()

        self.correct_btn.disconnect()
        self.correct_btn.clicked.connect(self.show_annotation_buttons)

    def del_cell(self):
        """Mark the selected cell as deleted by giving it the value 99."""
        logger.info(
            f"User interactions: Deleting cell #{self.track_of_interest} (setting status to 99)"
        )
        self.time_of_interest_le.setEnabled(False)
        self.time_of_interest_le.setText("99")
        self.apply_modification()