celldetective 1.5.0b1__py3-none-any.whl → 1.5.0b3__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (35)
  1. celldetective/_version.py +1 -1
  2. celldetective/gui/InitWindow.py +51 -12
  3. celldetective/gui/base/components.py +22 -1
  4. celldetective/gui/base_annotator.py +20 -9
  5. celldetective/gui/control_panel.py +21 -16
  6. celldetective/gui/event_annotator.py +51 -1060
  7. celldetective/gui/gui_utils.py +14 -5
  8. celldetective/gui/interactions_block.py +55 -25
  9. celldetective/gui/interactive_timeseries_viewer.py +11 -1
  10. celldetective/gui/measure_annotator.py +1064 -0
  11. celldetective/gui/plot_measurements.py +2 -4
  12. celldetective/gui/plot_signals_ui.py +3 -4
  13. celldetective/gui/process_block.py +298 -72
  14. celldetective/gui/viewers/base_viewer.py +134 -3
  15. celldetective/gui/viewers/contour_viewer.py +4 -4
  16. celldetective/gui/workers.py +25 -10
  17. celldetective/measure.py +3 -0
  18. celldetective/napari/utils.py +29 -19
  19. celldetective/processes/load_table.py +55 -0
  20. celldetective/processes/measure_cells.py +107 -81
  21. celldetective/processes/track_cells.py +39 -39
  22. celldetective/segmentation.py +1 -1
  23. celldetective/tracking.py +9 -0
  24. celldetective/utils/data_loaders.py +21 -1
  25. celldetective/utils/image_loaders.py +3 -0
  26. celldetective/utils/masks.py +1 -1
  27. celldetective/utils/maths.py +14 -1
  28. {celldetective-1.5.0b1.dist-info → celldetective-1.5.0b3.dist-info}/METADATA +1 -1
  29. {celldetective-1.5.0b1.dist-info → celldetective-1.5.0b3.dist-info}/RECORD +35 -32
  30. tests/gui/test_enhancements.py +351 -0
  31. tests/test_notebooks.py +2 -1
  32. {celldetective-1.5.0b1.dist-info → celldetective-1.5.0b3.dist-info}/WHEEL +0 -0
  33. {celldetective-1.5.0b1.dist-info → celldetective-1.5.0b3.dist-info}/entry_points.txt +0 -0
  34. {celldetective-1.5.0b1.dist-info → celldetective-1.5.0b3.dist-info}/licenses/LICENSE +0 -0
  35. {celldetective-1.5.0b1.dist-info → celldetective-1.5.0b3.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,1064 @@
1
+ from PyQt5.QtWidgets import (
2
+ QApplication,
3
+ QHBoxLayout,
4
+ QVBoxLayout,
5
+ QLabel,
6
+ QLineEdit,
7
+ QPushButton,
8
+ QMessageBox,
9
+ QSlider,
10
+ QComboBox,
11
+ )
12
+ from PyQt5.QtCore import Qt, QSize
13
+ from PyQt5.QtGui import QIntValidator, QKeySequence
14
+ import numpy as np
15
+ import pandas as pd
16
+ import os
17
+ import matplotlib.pyplot as plt
18
+ from matplotlib.cm import tab10
19
+ from superqt import QLabeledDoubleSlider
20
+ from superqt.fonticon import icon
21
+ from fonticon_mdi6 import MDI6
22
+
23
+ from celldetective.gui.base_annotator import BaseAnnotator
24
+ from celldetective.gui.viewers.contour_viewer import CellEdgeVisualizer
25
+ from celldetective.gui.base.components import CelldetectiveWidget
26
+ from celldetective.gui.gui_utils import color_from_state, color_from_class
27
+ from celldetective.utils.image_loaders import locate_labels, load_frames
28
+ from celldetective.utils.masks import contour_of_instance_segmentation
29
+ from celldetective.gui.base.figure_canvas import FigureCanvas
30
+ from celldetective.gui.base.utils import center_window
31
+ from celldetective import get_logger
32
+
33
+ logger = get_logger(__name__)
34
+
35
+
36
class AnnotatorStackVisualizer(CellEdgeVisualizer):
    """
    Stack viewer used by the static annotator.

    Extends ``CellEdgeVisualizer`` with a pickable scatter overlay that marks
    cell positions, hides the edge-size slider, and packs the remaining
    controls (opacity slider and the parent's sliders) into a compact layout.
    """

    def __init__(self, *args, **kwargs):
        # Must exist before super().__init__, which builds the canvas and,
        # through generate_figure_canvas, creates the scatter artist.
        self.scat_markers = None
        super().__init__(*args, **kwargs)
        self.compact_layout()

    def generate_figure_canvas(self):
        """Build the parent canvas, add the scatter overlay, then compact the layout."""
        super().generate_figure_canvas()
        self.generate_custom_overlays()
        # Force layout update
        self.compact_layout()

    def generate_custom_overlays(self):
        """Initialize the (empty) pickable scatter artist used to mark cells."""
        self.scat_markers = self.ax.scatter([], [], color="tab:red", picker=True)
        # CellEdgeVisualizer handles self.im_mask

    def update_overlays(self, positions, colors):
        """
        Move the scatter markers to *positions* and recolor their edges.

        Parameters
        ----------
        positions : array-like
            (x, y) coordinates of the cells to mark. An empty sequence clears
            the markers.
        colors : array-like
            One edge color per position. Marker styling is only refreshed when
            both positions and colors are non-empty.
        """
        # Collection.set_offsets indexes offsets[:, 0]; a bare empty list (as
        # passed by callers for frames without data) has shape (0,) and would
        # raise IndexError. Normalize to an (N, 2) float array first.
        offsets = np.asarray(positions, dtype=float).reshape(-1, 2)
        self.scat_markers.set_offsets(offsets)
        if len(offsets) > 0 and len(colors) > 0:
            # colors should be an array of colors
            self.scat_markers.set_edgecolors(colors)
            self.scat_markers.set_facecolors("none")
            self.scat_markers.set_picker(10)
            self.scat_markers.set_linewidths(2)
            self.scat_markers.set_sizes([200] * len(offsets))

        self.canvas.canvas.draw_idle()

    def generate_edge_slider(self):
        """Intentionally disabled: the annotator does not expose the edge slider."""
        pass

    def generate_opacity_slider(self):
        """Add a compact labeled slider (0-1) controlling the label-mask opacity."""
        self.opacity_slider = QLabeledDoubleSlider()
        self.opacity_slider.setOrientation(Qt.Horizontal)
        self.opacity_slider.setRange(0, 1)
        self.opacity_slider.setValue(self.mask_alpha)
        self.opacity_slider.setDecimals(2)
        self.opacity_slider.valueChanged.connect(self.change_mask_opacity)

        layout = QHBoxLayout()
        layout.setContentsMargins(0, 0, 0, 0)
        layout.setSpacing(0)
        layout.addWidget(QLabel("Opacity:"), 15)
        layout.addWidget(self.opacity_slider, 85)
        self.canvas.layout.addLayout(layout)

    def compact_layout(self):
        """Remove margins and spacing from the canvas layout and its nested layouts."""
        self.canvas.layout.setSpacing(0)
        self.canvas.layout.setContentsMargins(0, 5, 0, 5)
        for i in range(self.canvas.layout.count()):
            sub_layout = self.canvas.layout.itemAt(i).layout()
            # Direct widget items need no adjustment.
            if sub_layout is not None:
                sub_layout.setContentsMargins(0, 0, 0, 0)
                sub_layout.setSpacing(0)
98
+
99
+
100
+ class MeasureAnnotator(BaseAnnotator):
101
+
102
def __init__(self, *args, **kwargs):
    """
    Build the static (per-frame) measurement annotator window.

    ``status_name`` is assigned before ``super().__init__`` because
    ``locate_tracks`` reads it; presumably the base initializer is what calls
    ``locate_tracks`` — TODO confirm in BaseAnnotator. When the base class sets
    ``self.proceed``, labels are located (and repaired/trimmed to
    ``len_movie`` frames), the window is populated and the active class is
    applied; otherwise the window is closed.
    """

    self.status_name = "group"
    super().__init__(read_config=False, *args, **kwargs)

    self.setWindowTitle("Static annotator")

    self.int_validator = QIntValidator()
    self.current_alpha = 0.5
    self.value_magnitude = 1

    epsilon = 0.01
    self.observed_min_intensity = 0
    self.observed_max_intensity = 0 + epsilon

    self.current_frame = 0
    self.show_fliers = False

    if self.proceed:

        from celldetective.utils.image_loaders import fix_missing_labels
        # NOTE(review): write_first_detection_class is imported but not used here.
        from celldetective.tracking import write_first_detection_class

        # Ensure labels match stack length
        if self.len_movie > 0:
            temp_labels = locate_labels(self.pos, population=self.mode)
            if temp_labels is None or len(temp_labels) < self.len_movie:
                # Missing or too few label frames: regenerate, then reload.
                fix_missing_labels(
                    self.pos,
                    population=self.mode,
                    prefix=self.parent_window.movie_prefix,
                )
                self.labels = locate_labels(self.pos, population=self.mode)
            elif len(temp_labels) > self.len_movie:
                # More label frames than movie frames: keep the first len_movie.
                self.labels = temp_labels[: self.len_movie]
            else:
                self.labels = temp_labels
        else:
            self.labels = locate_labels(self.pos, population=self.mode)

        self.current_channel = 0
        self.frame_lbl = QLabel("position: ")

        # self.static_image() # Replaced by StackVisualizer initialization in populate_window

        self.populate_window()
        self.changed_class()

        self.previous_index = None

    else:
        self.close()
154
+
155
def locate_tracks(self):
    """
    Load the measurement table and prepare it for static annotation.

    Reads ``self.trajectories_path`` into ``self.df_tracks``. The table is
    sorted by ``TRACK_ID`` (or ``ID`` when absent) and then ``FRAME``.

    Then:

    - derives missing ``status_*`` columns from ``class_*`` columns;
    - chooses the active annotation column ``self.status_name``, creating it
      if needed, and assigns display colours (``group_color``);
    - drops rows without a position and builds the per-frame scatter data;
    - picks the smallest identifier as the initial cell of interest;
    - fits ``self.MinMaxScaler`` on the remaining feature columns
      (``self.columns_to_rescale``).

    If the file does not exist, a warning dialog is shown and the window is
    closed.
    """

    if not os.path.exists(self.trajectories_path):
        msgBox = QMessageBox()
        msgBox.setIcon(QMessageBox.Warning)
        msgBox.setText("The trajectories cannot be detected.")
        msgBox.setWindowTitle("Warning")
        msgBox.setStandardButtons(QMessageBox.Ok)
        returnValue = msgBox.exec()
        # The dialog only offers "Ok"; the former check against
        # QMessageBox.Yes could never match, so the window never closed.
        if returnValue == QMessageBox.Ok:
            self.close()
        return

    # Helper columns that hold ids/colours, never annotations themselves.
    to_remove = [
        "class_id",
        "group_color",
        "class_color",
        "group_id",
        "status_color",
        "status_id",
    ]

    def find_class_columns(prefixes):
        # Table columns starting with one of `prefixes`, minus helper columns.
        return [
            c
            for c in self.df_tracks.columns
            if c.startswith(prefixes) and c not in to_remove
        ]

    # Load and prep tracks
    self.df_tracks = pd.read_csv(self.trajectories_path)
    id_col = "TRACK_ID" if "TRACK_ID" in self.df_tracks.columns else "ID"
    self.df_tracks = self.df_tracks.sort_values(by=[id_col, "FRAME"])

    self.class_cols = find_class_columns(("group", "class", "status"))

    # Generate missing status columns from class columns
    for c in self.class_cols:
        if not c.startswith("class_"):
            continue
        status_col = c.replace("class_", "status_")
        if status_col in self.df_tracks.columns:
            continue
        if status_col == "status_firstdetection" or c == "class_firstdetection":
            try:
                from celldetective.tracking import write_first_detection_class

                self.df_tracks = write_first_detection_class(self.df_tracks)
            except Exception as e:
                logger.error(f"Could not generate status_firstdetection: {e}")
                # Fall back to mirroring the class column.
                self.df_tracks[status_col] = self.df_tracks[c]
        else:
            self.df_tracks[status_col] = self.df_tracks[c]

    # Re-evaluate after generation: only group*/status* columns are annotations.
    self.class_cols = find_class_columns(("group", "status"))

    if len(self.class_cols) > 0:
        if self.status_name not in self.class_cols:
            self.status_name = self.class_cols[0]
    else:
        self.status_name = "group"

    # Only create the status column if it does not exist, so existing static
    # classification results are not erased.
    if self.status_name not in self.df_tracks.columns:
        self.make_status_column()

    all_states = np.array(self.df_tracks.loc[:, self.status_name].tolist())
    self.state_color_map = color_from_state(all_states, recently_modified=False)
    self.df_tracks["group_color"] = self.df_tracks[self.status_name].apply(
        self.assign_color_state
    )

    # .copy(): columns are added right after, so work on an owned frame rather
    # than a filtered result (avoids SettingWithCopy issues).
    self.df_tracks = self.df_tracks.dropna(subset=["POSITION_X", "POSITION_Y"]).copy()
    self.df_tracks["x_anim"] = self.df_tracks["POSITION_X"].astype(int)
    self.df_tracks["y_anim"] = self.df_tracks["POSITION_Y"].astype(int)

    self.extract_scatter_from_trajectories()
    self.track_of_interest = self.df_tracks.dropna(subset=id_col)[id_col].min()

    # Frames (and per-frame positions) where the cell of interest appears.
    self.loc_t = []
    self.loc_idx = []
    for t, frame_ids in enumerate(self.tracks):
        indices = np.where(frame_ids == self.track_of_interest)[0]
        if len(indices) > 0:
            self.loc_t.append(t)
            self.loc_idx.append(indices[0])

    from sklearn.preprocessing import MinMaxScaler
    from celldetective.utils.experiment import (
        get_experiment_metadata,
        get_experiment_labels,
    )

    self.MinMaxScaler = MinMaxScaler()

    # Identifier, annotation, display and metadata columns are not features.
    cols_to_remove = [
        "group",
        "group_color",
        "status",
        "status_color",
        "class_color",
        "TRACK_ID",
        "FRAME",
        "x_anim",
        "y_anim",
        "t",
        "dummy",
        "state",
        "generation",
        "root",
        "parent",
        "class_id",
        "class",
        "t0",
        "POSITION_X",
        "POSITION_Y",
        "position",
        "well",
        "well_index",
        "well_name",
        "pos_name",
        "index",
        "concentration",
        "cell_type",
        "antibody",
        "pharmaceutical_agent",
        "ID",
    ] + self.class_cols

    meta = get_experiment_metadata(self.exp_dir)
    if meta is not None:
        cols_to_remove.extend(meta.keys())

    labels = get_experiment_labels(self.exp_dir)
    if labels is not None:
        cols_to_remove.extend(labels.keys())

    excluded = set(cols_to_remove)
    self.columns_to_rescale = [
        c for c in self.df_tracks.columns if c not in excluded
    ]

    if len(self.columns_to_rescale) > 0:
        x = self.df_tracks[self.columns_to_rescale].values
        self.MinMaxScaler.fit(x)
    else:
        # Fitting on zero features raises in scikit-learn; keep the window usable.
        logger.warning("No feature column left to rescale; MinMaxScaler not fitted.")
+
350
def populate_options_layout(self):
    """
    Rebuild the options row with the phenotype editor and the delete button.

    Everything currently held by ``self.options_hbox`` is detached first.
    Then a row is added containing ``time_of_interest_label``, the
    integer-validated ``time_of_interest_le`` and ``suppr_btn`` (wired to
    ``del_cell``).
    """

    def clear_layout(layout):
        # itemAt(i).widget() is None for nested layouts and spacers; calling
        # setParent on it crashed. Detach widgets, recurse into nested
        # layouts, and drop the non-widget items.
        for i in reversed(range(layout.count())):
            item = layout.itemAt(i)
            widget = item.widget()
            if widget is not None:
                widget.setParent(None)
                continue
            nested = item.layout()
            if nested is not None:
                clear_layout(nested)
            layout.removeItem(item)

    # clear options hbox
    clear_layout(self.options_hbox)

    time_option_hbox = QHBoxLayout()
    time_option_hbox.setContentsMargins(100, 0, 100, 0)
    time_option_hbox.setSpacing(0)

    self.time_of_interest_label = QLabel("phenotype: ")
    time_option_hbox.addWidget(self.time_of_interest_label, 30)

    self.time_of_interest_le = QLineEdit()
    self.time_of_interest_le.setValidator(self.int_validator)
    time_option_hbox.addWidget(self.time_of_interest_le)

    self.suppr_btn = QPushButton("")
    self.suppr_btn.setStyleSheet(self.button_select_all)
    self.suppr_btn.setIcon(icon(MDI6.delete, color="black"))
    self.suppr_btn.setToolTip("Delete cell")
    self.suppr_btn.setIconSize(QSize(20, 20))
    self.suppr_btn.clicked.connect(self.del_cell)
    time_option_hbox.addWidget(self.suppr_btn)

    self.options_hbox.addLayout(time_option_hbox)
375
+
376
def update_widgets(self):
    """
    Adapt the inherited class controls to characteristic groups.

    Relabels the class selector, repopulates its combobox, updates the
    add/delete tooltips and routes the export button to
    ``export_measurements`` instead of its previous handler.
    """
    self.class_label.setText("characteristic \n group: ")
    self.update_class_cb()

    for button, tip in (
        (self.add_class_btn, "Add a new characteristic group"),
        (self.del_class_btn, "Delete a characteristic group"),
    ):
        button.setToolTip(tip)

    # Swap the export handler installed by the base class.
    export_button = self.export_btn
    export_button.disconnect()
    export_button.clicked.connect(self.export_measurements)
385
+
386
def update_class_cb(self):
    """
    Refill the class combobox with the annotation columns of ``df_tracks``.

    Only ``group``, ``group_*`` and ``status_*`` columns are offered, excluding
    the helper ``*_id`` / ``*_color`` columns; the list is stored in
    ``self.class_cols``. The change handler is detached while the items are
    rebuilt, ``status_name`` is reselected when still available, and
    ``changed_class`` is reconnected at the end.
    """
    combo = self.class_choice_cb
    combo.disconnect()
    combo.clear()

    helper_columns = {
        "group_id",
        "group_color",
        "class_id",
        "class_color",
        "status_color",
        "status_id",
    }

    def is_annotation_column(name):
        is_group_or_status = (
            name == "group"
            or name.startswith("group_")
            or name.startswith("status_")
        )
        return is_group_or_status and name not in helper_columns

    self.class_cols = [
        name for name in self.df_tracks.columns if is_annotation_column(name)
    ]

    combo.addItems(self.class_cols)
    if self.status_name in self.class_cols:
        combo.setCurrentText(self.status_name)
    combo.currentIndexChanged.connect(self.changed_class)
429
+
430
def populate_window(self):
    """
    Build the window: base layout, options row, class controls and stack viewer.

    The phenotype-editing widgets start hidden. The right panel receives an
    ``AnnotatorStackVisualizer`` over ``self.stack_path`` and ``self.labels``.
    Its frame slider drives ``sync_frame``, its channel combobox re-runs
    ``plot_signals``, and matplotlib pick events go to ``on_scatter_pick``.
    The viewer is reset to frame 0 before the first signal plot is drawn.
    """

    super().populate_window()
    # Left panel updates
    self.populate_options_layout()
    self.update_widgets()

    # Widgets only shown while a phenotype correction is being entered.
    self.annotation_btns_to_hide = [
        self.time_of_interest_label,
        self.time_of_interest_le,
        self.suppr_btn,
    ]
    self.hide_annotation_buttons()

    # Right panel - Initialize StackVisualizer
    self.viewer = AnnotatorStackVisualizer(
        stack_path=self.stack_path,
        labels=self.labels,
        frame_slider=True,
        channel_cb=True,
        channel_names=self.channel_names,
        n_channels=self.nbr_channels,
        target_channel=0,
        window_title="Stack Viewer",
    )

    # Connect viewer signals
    self.viewer.frame_slider.valueChanged.connect(self.sync_frame)
    self.viewer.channel_cb.currentIndexChanged.connect(self.plot_signals)

    # Connect mpl event
    self.cid_pick = self.viewer.fig.canvas.mpl_connect(
        "pick_event", self.on_scatter_pick
    )

    self.right_panel.addWidget(self.viewer.canvas)

    # Force start at frame 0
    self.viewer.frame_slider.setValue(0)

    self.plot_signals()
    self.compact_layout_main()
472
+
473
def compact_layout_main(self):
    """Re-apply the viewer's compact layout from the main window, if a viewer exists yet."""
    if not hasattr(self, "viewer"):
        return
    self.viewer.compact_layout()
477
+
478
def sync_frame(self, value):
    """
    Callback when the StackVisualizer frame changes.

    Connected to the viewer's frame slider in ``populate_window``: stores the
    new frame index *value* as ``current_frame`` and runs the per-frame update.
    """

    self.current_frame = value
    self.update_frame_logic()
483
+
484
def plot_signals(self):
    """
    Draw the feature distributions of the current frame in the left-panel axes.

    For each selected signal (combobox text other than "--"), draws:

    - a boxplot of all values in the table;
    - a jittered strip of the values at the current frame;
    - last, one marker per signal for the cell of interest at the current
      frame. The marker is NaN, so not drawn, when that cell has no row there.

    The selected-cell markers are always the last line on the axes.
    ``select_single_cell`` relies on this when it removes ``lines[-1]``.
    Returns immediately while the stack viewer does not exist yet.
    """
    if not hasattr(self, "viewer"):
        return

    current_frame = self.current_frame
    id_col = "TRACK_ID" if "TRACK_ID" in self.df_tracks.columns else "ID"
    frame_rows = self.df_tracks["FRAME"] == current_frame
    # Filter by frame in both modes, as plot_red_points and apply_modification
    # do. Ignoring the frame in ID mode returned one value per frame where the
    # ID occurs, which misaligned x and y for the selected-cell markers.
    cell_rows = frame_rows & (self.df_tracks[id_col] == self.track_of_interest)

    selected_values = []  # exactly one entry per selected signal (NaN if missing)
    all_yvalues = []
    current_yvalues = []
    range_values = []

    for combo in self.signal_choice_cb:
        signal_choice = combo.currentText()
        if signal_choice == "--":
            continue

        cell_values = self.df_tracks.loc[cell_rows, signal_choice].to_numpy()
        selected_values.append(cell_values[0] if len(cell_values) > 0 else np.nan)

        current_ydata = self.df_tracks.loc[frame_rows, signal_choice].to_numpy()
        current_yvalues.append(current_ydata[current_ydata == current_ydata])  # drop NaN

        all_ydata = self.df_tracks.loc[:, signal_choice].to_numpy()
        all_ydata = all_ydata[all_ydata == all_ydata]  # drop NaN
        all_yvalues.append(all_ydata)
        range_values.extend(all_ydata)

    self.cell_ax.clear()
    self.cell_ax.set_yscale("log" if self.log_scale else "linear")

    if len(selected_values) > 0:
        try:
            self.cell_ax.boxplot(all_yvalues, showfliers=self.show_fliers)
        except Exception as e:
            logger.error(f"{e=}")

        for index, feature in enumerate(current_yvalues):
            x_values_strip = (index + 1) + np.random.normal(0, 0.04, size=len(feature))
            self.cell_ax.plot(
                x_values_strip,
                feature,
                marker="o",
                linestyle="None",
                color=tab10.colors[0],
                alpha=0.1,
            )

        # Selected-cell markers: must remain the last line drawn.
        x_pos = np.arange(len(selected_values)) + 1
        self.cell_ax.plot(
            x_pos,
            selected_values,
            marker="H",
            linestyle="None",
            color=tab10.colors[3],
            alpha=1,
        )

        bounds = np.array(range_values)
        bounds = bounds[bounds == bounds]
        # Filter out non-positive values if log scale is active to prevent warnings
        if self.log_scale:
            bounds = bounds[bounds > 0]

        if len(bounds) > 0:
            vmin, vmax = np.nanmin(bounds), np.nanmax(bounds)
            margin = 0.03 * (vmax - vmin)
            self.value_magnitude = vmin - margin
            self.non_log_ymin = vmin - margin
            self.non_log_ymax = vmax + margin
            if self.cell_ax.get_yscale() == "linear":
                self.cell_ax.set_ylim(self.non_log_ymin, self.non_log_ymax)
            else:
                self.cell_ax.set_ylim(self.value_magnitude, self.non_log_ymax)
        else:
            # Nothing finite to frame; np.nanmin on an empty array would raise.
            self.value_magnitude = 1
    else:
        self.cell_ax.text(
            0.5,
            0.5,
            "No data available",
            horizontalalignment="center",
            verticalalignment="center",
            transform=self.cell_ax.transAxes,
        )

    self.cell_fcanvas.canvas.draw()
598
+
599
def plot_red_points(self, ax):
    """
    Draw the selected-cell markers on *ax* and redraw the signal canvas.

    One marker per selected signal (combobox text other than "--") is placed
    at x = 1, 2, ... to line up with the boxplots drawn by ``plot_signals``.
    A signal with no value for the cell of interest at the current frame gets
    NaN. Matplotlib leaves NaN undrawn, so later markers keep their slot
    instead of shifting left.
    """
    df = self.df_tracks
    id_col = "TRACK_ID" if "TRACK_ID" in df.columns else "ID"
    cell_rows = (df[id_col] == self.track_of_interest) & (df["FRAME"] == self.current_frame)

    yvalues = []
    for combo in self.signal_choice_cb:
        signal_choice = combo.currentText()
        if signal_choice == "--":
            continue
        cell_values = df.loc[cell_rows, signal_choice].to_numpy()
        yvalues.append(cell_values[0] if len(cell_values) > 0 else np.nan)

    x_pos = np.arange(len(yvalues)) + 1
    ax.plot(
        x_pos, yvalues, marker="H", linestyle="None", color=tab10.colors[3], alpha=1
    )  # Plot red points representing cells
    self.cell_fcanvas.canvas.draw()
624
+
625
def select_single_cell(self, index, timepoint):
    """
    Make the cell at position *index* in frame *timepoint* the cell of interest.

    Steps:

    - enable the correct/cancel buttons and the delete shortcut;
    - refresh the info label and the selected-cell markers of the signal plot;
    - record every frame where the cell appears (``loc_t`` / ``loc_idx``);
    - paint those points lime, saving their previous colours in
      ``previous_color``;
    - redraw the current frame.
    """
    for control in (self.correct_btn, self.cancel_btn, self.del_shortcut):
        control.setEnabled(True)

    self.track_of_interest = self.tracks[timepoint][index]
    logger.info(f"You selected cell #{self.track_of_interest}...")
    self.give_cell_information()

    if len(self.cell_ax.lines) == 0:
        self.plot_signals()
    else:
        # The last line holds the previous selected-cell markers: replace it.
        self.cell_ax.lines[-1].remove()
        self.plot_red_points(self.cell_ax)

    self.loc_t, self.loc_idx = [], []
    for frame, frame_ids in enumerate(self.tracks):
        matches = np.where(frame_ids == self.track_of_interest)[0]
        if len(matches) > 0:
            self.loc_t.append(frame)
            self.loc_idx.append(matches[0])

    self.previous_color = []
    for frame, position in zip(self.loc_t, self.loc_idx):
        frame_colors = self.colors[frame]
        self.previous_color.append(frame_colors[position].copy())
        frame_colors[position] = "lime"

    self.draw_frame(self.current_frame)
657
+
658
def cancel_selection(self):
    """
    Drop the current selection, forget the last pick event and redraw the frame.

    The deselection itself is delegated to the base-class implementation.
    """
    super().cancel_selection()
    # Clear the matplotlib pick event stored by on_scatter_pick.
    self.event = None
    self.draw_frame(self.current_frame)
662
+
663
def export_measurements(self):
    """
    Export the current frame's annotated cells as a ``.npy`` list of records.

    Each record is one row of ``df_tracks`` at ``current_frame``, with an extra
    ``class`` key copied from the active annotation column ``status_name``.
    The default file name is built from parts of ``self.pos``, the frame number
    and ``status_name``. The user picks the destination in a save dialog, and
    ``.npy`` is appended when missing. Write errors are logged, not raised.
    """
    # QFileDialog is not among the module-level imports; the previous code
    # raised NameError here.
    from PyQt5.QtWidgets import QFileDialog

    logger.info("User interactions: Exporting measurements...")
    path_parts = self.pos.split(os.sep)
    auto_dataset_name = (
        path_parts[-4]
        + "_"
        + path_parts[-2]
        + f"_{str(self.current_frame).zfill(3)}"
        + f"_{self.status_name}.npy"
    )

    # Export raw values: undo normalization via its button first
    # (assumes the button toggles normalization — TODO confirm).
    if self.normalized_signals:
        self.normalize_features_btn.click()

    # .copy(): adding "class" must not write into (or warn about) a slice of df_tracks.
    subdf = self.df_tracks.loc[self.df_tracks["FRAME"] == self.current_frame, :].copy()
    subdf["class"] = subdf[self.status_name]
    dico = subdf.to_dict("records")

    pathsave = QFileDialog.getSaveFileName(
        self,
        "Select file name",
        # os.path.join works whether or not exp_dir ends with a separator.
        os.path.join(self.exp_dir, auto_dataset_name),
        "NumPy files (*.npy)",
    )[0]
    if pathsave != "":
        if not pathsave.endswith(".npy"):
            pathsave += ".npy"
        try:
            np.save(pathsave, dico)
            logger.info(f"File successfully written in {pathsave}.")
        except Exception as e:
            logger.error(f"Error {e}...")
692
+
693
def write_new_event_class(self):
    """
    Create (or reset) a characteristic-group column from the dialog and select it.

    The column is named ``group`` when the name field is empty, otherwise
    ``group_<name>``. Names that collide with the helper columns
    ``group_id`` / ``group_color`` are refused. If the column already exists,
    the user must confirm before all of its values are reset to 0. The class
    combobox is rebuilt, the column becomes the active annotation column, the
    display is refreshed and the naming dialog is closed.

    Returns
    -------
    None
    """

    if self.class_name_le.text() == "":
        self.target_class = "group"
    else:
        self.target_class = "group_" + self.class_name_le.text()

    # group_color is rewritten on every recolouring and group_id/group_color are
    # filtered out of the combobox, so these names would silently break things.
    if self.target_class in ("group_id", "group_color"):
        msgBox = QMessageBox()
        msgBox.setIcon(QMessageBox.Warning)
        msgBox.setText(
            f"'{self.target_class}' is reserved for internal use. Please choose another group name."
        )
        msgBox.setWindowTitle("Warning")
        msgBox.setStandardButtons(QMessageBox.Ok)
        msgBox.exec()
        return None

    logger.info(
        f"User interactions: Creating new characteristic group '{self.target_class}'"
    )

    if self.target_class in list(self.df_tracks.columns):
        msgBox = QMessageBox()
        msgBox.setIcon(QMessageBox.Warning)
        msgBox.setText(
            "This characteristic group name already exists. If you proceed,\nall annotated data will be rewritten. Do you wish to continue?"
        )
        msgBox.setWindowTitle("Warning")
        msgBox.setStandardButtons(QMessageBox.Yes | QMessageBox.No)
        returnValue = msgBox.exec()
        if returnValue == QMessageBox.No:
            return None

    self.df_tracks.loc[:, self.target_class] = 0

    self.update_class_cb()

    idx = self.class_choice_cb.findText(self.target_class)
    self.status_name = self.target_class
    if idx == self.class_choice_cb.currentIndex():
        # Overwriting the active group: setCurrentIndex would not emit
        # currentIndexChanged, so changed_class must be triggered explicitly.
        self.changed_class()
    else:
        self.class_choice_cb.setCurrentIndex(idx)
    self.newClassWidget.close()
726
+
727
def hide_annotation_buttons(self):
    """Hide the phenotype-editing widgets and leave the phenotype field empty and disabled."""
    for widget in self.annotation_btns_to_hide:
        widget.hide()

    self.time_of_interest_label.setEnabled(False)
    field = self.time_of_interest_le
    field.setText("")
    field.setEnabled(False)
734
+
735
def show_annotation_buttons(self):
    """
    Reveal and enable the phenotype editor and turn the correct button into "submit".

    The correct button's handler is replaced so that its next click calls
    ``apply_modification``.
    """
    for widget in self.annotation_btns_to_hide:
        widget.show()

    for editable in (self.time_of_interest_label, self.time_of_interest_le):
        editable.setEnabled(True)

    button = self.correct_btn
    button.setText("submit")
    button.disconnect()
    button.clicked.connect(self.apply_modification)
746
+
747
def give_cell_information(self):
    """
    Show the selected cell's identifier and phenotype in ``self.cell_info``.

    The phenotype is read from the active annotation column ``status_name``
    for the row matching the cell of interest (``TRACK_ID``, or ``ID`` when
    there is no ``TRACK_ID`` column) at ``current_frame``.

    Falls back to "N/A":

    - when no such row exists;
    - with a note when the column itself is missing.

    Errors are logged, not raised.
    """
    try:
        cell_selected = f"cell: {self.track_of_interest}\n"
        if self.status_name in self.df_tracks.columns:
            id_col = "TRACK_ID" if "TRACK_ID" in self.df_tracks.columns else "ID"
            # Filter by frame in both modes, as apply_modification does: the
            # ID branch previously ignored the frame and reported the value
            # from the first frame where the ID occurred.
            val = self.df_tracks.loc[
                (self.df_tracks["FRAME"] == self.current_frame)
                & (self.df_tracks[id_col] == self.track_of_interest),
                self.status_name,
            ].to_numpy()
            if len(val) > 0:
                cell_status = f"phenotype: {val[0]}\n"
            else:
                cell_status = "phenotype: N/A\n"
        else:
            cell_status = f"phenotype: N/A (col '{self.status_name}' missing)\n"
        self.cell_info.setText(cell_selected + cell_status)
    except Exception as e:
        logger.error(f"Error in give_cell_information: {e}")
775
+
776
def create_new_event_class(self):
    """
    Open a small dialog asking for the name of a new characteristic group.

    The dialog is kept on ``self.newClassWidget`` and the name field on
    ``self.class_name_le``. "submit" calls ``write_new_event_class``;
    "cancel" calls ``close_without_new_class``.
    """
    dialog = CelldetectiveWidget()
    self.newClassWidget = dialog
    dialog.setWindowTitle("Create new characteristic group")

    main_layout = QVBoxLayout()
    dialog.setLayout(main_layout)

    name_row = QHBoxLayout()
    name_row.addWidget(QLabel("group name: "), 25)
    self.class_name_le = QLineEdit("group")
    name_row.addWidget(self.class_name_le, 75)
    main_layout.addLayout(name_row)

    submit_button = QPushButton("submit")
    cancel_button = QPushButton("cancel")
    button_row = QHBoxLayout()
    for button in (cancel_button, submit_button):
        button_row.addWidget(button, 50)
    main_layout.addLayout(button_row)

    submit_button.clicked.connect(self.write_new_event_class)
    cancel_button.clicked.connect(self.close_without_new_class)

    dialog.show()
    center_window(dialog)
802
+
803
def apply_modification(self):
    """
    Store the phenotype typed in the editor for the selected cell at the current frame.

    The value of ``time_of_interest_le`` (0 when empty) is written to the
    active annotation column ``status_name``. The target row matches the cell
    of interest (by ``TRACK_ID``, or ``ID`` without tracks) and
    ``current_frame``.

    Afterwards:

    - colours, scatter data and the info label are refreshed;
    - the editor is hidden again and the buttons are reset;
    - the selection is released.

    Non-numeric text (QIntValidator lets intermediate input such as a lone
    "-" through) is rejected without modifying the table.
    """
    text = self.time_of_interest_le.text()
    if text == "":
        status = 0
    else:
        try:
            status = int(text)
        except ValueError:
            logger.warning(f"Invalid phenotype value '{text}'; no modification applied.")
            return

    logger.info(
        f"User interactions: Reclassifying cell #{self.track_of_interest} at frame {self.current_frame} to status {status}"
    )
    id_col = "TRACK_ID" if "TRACK_ID" in self.df_tracks.columns else "ID"
    target_rows = (self.df_tracks[id_col] == self.track_of_interest) & (
        self.df_tracks["FRAME"] == self.current_frame
    )
    # Single boolean-mask assignment; the former second pass through index
    # labels was redundant and could touch other rows if labels repeated.
    self.df_tracks.loc[target_rows, self.status_name] = status

    all_states = np.array(self.df_tracks.loc[:, self.status_name].tolist())
    self.state_color_map = color_from_state(all_states, recently_modified=False)

    self.df_tracks["group_color"] = self.df_tracks[self.status_name].apply(
        self.assign_color_state
    )
    self.extract_scatter_from_trajectories()
    self.give_cell_information()

    # Next click on the correct button reopens the editor.
    self.correct_btn.disconnect()
    self.correct_btn.clicked.connect(self.show_annotation_buttons)

    self.hide_annotation_buttons()
    self.correct_btn.setEnabled(False)
    self.correct_btn.setText("correct")
    self.cancel_btn.setEnabled(False)
    self.del_shortcut.setEnabled(False)

    if len(self.selection) > 0:
        self.selection.pop(0)

    self.draw_frame(self.current_frame)
859
+
860
def assign_color_state(self, state):
    """
    Look up the display colour of *state* in ``self.state_color_map``.

    NaN states are looked up under the key "nan". Values that ``np.isnan``
    cannot handle (e.g. strings) are used as keys unchanged.
    """
    try:
        is_missing = bool(np.isnan(state))
    except TypeError:
        is_missing = False
    key = "nan" if is_missing else state
    return self.state_color_map[key]
868
+
869
def on_scatter_pick(self, event):
    """
    Handle a matplotlib pick event on the cell scatter overlay.

    When several points are picked, the one closest to the mouse position is
    kept. Clicking the already selected cell (same point in the same frame)
    toggles the selection off. Otherwise any previous selection is cancelled
    and the picked cell is selected in the current frame.
    """
    self.event = event
    ind = event.ind

    if len(ind) > 1:
        # disambiguate based on distance to mouse click
        frame_positions = self.positions[self.current_frame]
        datax = np.array([frame_positions[i, 0] for i in ind])
        datay = np.array([frame_positions[i, 1] for i in ind])
        msx, msy = event.mouseevent.xdata, event.mouseevent.ydata
        dist = np.sqrt((datax - msx) ** 2 + (datay - msy) ** 2)
        ind = [ind[np.argmin(dist)]]

    if len(ind) == 0:
        return

    idx = ind[0]

    # Enforce single selection / Toggle
    if len(self.selection) > 0:
        prev_idx, prev_frame = self.selection[0]
        self.cancel_selection()
        # Scatter indices are per frame: the same index in another frame is a
        # different cell, so the frame must match too for a toggle-off.
        if prev_idx == idx and prev_frame == self.current_frame:
            return

    self.selection = [[idx, self.current_frame]]
    self.select_single_cell(idx, self.current_frame)
902
+
903
def draw_frame(self, framedata):
    """
    Refresh the viewer's scatter overlay for frame *framedata*.

    Frames beyond the available per-frame data get an empty overlay. The label
    mask is not touched here; the viewer (CellEdgeVisualizer) updates it on
    frame change.
    """
    self.framedata = framedata
    in_range = self.framedata < len(self.positions)
    positions = self.positions[self.framedata] if in_range else []
    edge_colors = self.colors[self.framedata][:, 0] if in_range else []
    self.viewer.update_overlays(positions=positions, colors=edge_colors)
927
+
928
def make_status_column(self):
    """
    Create the active annotation column filled with 0 and rebuild its colours.

    Nothing is done when ``status_name`` is "state_firstdetection".
    Otherwise:

    - the column is set to 0;
    - ``state_color_map`` is recomputed from its values;
    - ``group_color`` is reassigned.
    """
    if self.status_name == "state_firstdetection":
        return

    self.df_tracks.loc[:, self.status_name] = 0
    states = np.array(self.df_tracks.loc[:, self.status_name].tolist())
    self.state_color_map = color_from_state(states, recently_modified=False)
    self.df_tracks["group_color"] = self.df_tracks[self.status_name].apply(
        self.assign_color_state
    )
939
+
940
def extract_scatter_from_trajectories(self):
    """Split the tracks table into per-frame arrays for the scatter overlay.

    For each frame ``t`` in ``range(self.len_movie)`` this appends:

    - to ``self.positions``: the ``POSITION_X``/``POSITION_Y`` values as an
      ``(n, 2)`` array;
    - to ``self.colors``: the ``group_color`` values as an ``(n, 1)`` array;
    - to ``self.tracks``: the ``TRACK_ID`` values, or ``ID`` when the table
      has no ``TRACK_ID`` column.

    Frames without rows yield empty arrays, so list position == frame index.
    """
    df = self.df_tracks
    # The identifier column does not change between frames; choose it once.
    id_column = "TRACK_ID" if "TRACK_ID" in df.columns else "ID"

    self.positions = []
    self.colors = []
    self.tracks = []

    for t in np.arange(self.len_movie):
        # Build the frame mask once and reuse it for all three extractions.
        in_frame = df["FRAME"] == t
        self.positions.append(
            df.loc[in_frame, ["POSITION_X", "POSITION_Y"]].to_numpy()
        )
        self.colors.append(df.loc[in_frame, ["group_color"]].to_numpy())
        self.tracks.append(df.loc[in_frame, id_column].to_numpy())
967
+
968
def compute_status_and_colors(self, index=0):
    """Recompute statuses and colours by delegating to ``changed_class``.

    Parameters
    ----------
    index : int, optional
        Unused. Presumably accepted so this can be connected to a Qt signal
        that emits an index; TODO confirm against the signal connections.
    """
    self.changed_class()
970
+
971
def changed_class(self):
    """Adopt the class selected in the combo box, then redraw.

    ``self.status_name`` is set to the combo box's current text. The colour
    and scatter data are rebuilt via ``modify`` only when that text is
    non-empty. The current frame is redrawn in every case.
    """
    selected_class = self.class_choice_cb.currentText()
    self.status_name = selected_class
    if selected_class != "":
        self.modify()
    self.draw_frame(self.current_frame)
977
+
978
def update_frame_logic(self):
    """Refresh frame-dependent state and views after a frame change.

    When the table lacks a ``TRACK_ID`` column but has an ``ID`` column,
    the cell of interest is switched to the smallest ``ID`` present on the
    current frame (if any) and the colour/scatter data are rebuilt. The
    frame is then redrawn, and the cell information and signal plots are
    refreshed.
    """
    columns = self.df_tracks.columns
    if "TRACK_ID" not in columns and "ID" in columns:
        on_frame = self.df_tracks["FRAME"] == self.current_frame
        frame_ids = self.df_tracks.loc[on_frame, "ID"]
        if not frame_ids.empty:
            self.track_of_interest = frame_ids.min()
            self.modify()

    self.draw_frame(self.current_frame)
    self.give_cell_information()
    self.plot_signals()
997
+
998
def changed_channel(self):
    """React to a channel change; intentionally does nothing.

    Per the original note, the stack viewer reloads the image itself. This
    hook is kept as the place to refresh channel-dependent plots if any are
    added.
    """
    return None
1001
+
1002
def save_trajectories(self):
    """Export the annotated tracks table to ``self.trajectories_path`` as CSV.

    Before writing, the normalisation button is clicked if signals are
    currently normalised, and any active selection is cancelled. The helper
    columns ``""``, ``group_color``, ``x_anim`` and ``y_anim`` are dropped
    when present. The table is then written without its index, and the
    tracks and class colouring are reloaded.
    """
    logger.info("Saving trajectories...")
    if self.normalized_signals:
        # Presumably toggles normalisation off so raw values are saved.
        self.normalize_features_btn.click()
    if self.selection:
        self.cancel_selection()

    # errors="ignore" skips absent columns. Unlike bare ``except: pass``,
    # this does not mask unrelated failures.
    self.df_tracks.drop(
        columns=["", "group_color", "x_anim", "y_anim"],
        errors="ignore",
        inplace=True,
    )

    self.df_tracks.to_csv(self.trajectories_path, index=False)
    logger.info("Table successfully exported...")

    self.locate_tracks()
    self.changed_class()
1036
+
1037
def modify(self):
    """Recolour cells from the current status column and refresh the views.

    If ``self.status_name`` is not a column of ``self.df_tracks``, a warning
    is logged and nothing else happens. Otherwise this:

    - rebuilds the state colour map with ``color_from_state``;
    - fills the per-row ``group_color`` column;
    - regenerates the per-frame scatter data;
    - refreshes the cell information;
    - re-wires ``correct_btn`` so that clicking it shows the annotation
      buttons.
    """
    status_column = self.status_name
    if status_column not in self.df_tracks.columns:
        logger.warning(
            f"Column '{self.status_name}' not found in df_tracks. Skipping modify."
        )
        return

    state_values = np.array(self.df_tracks.loc[:, status_column].tolist())
    self.state_color_map = color_from_state(state_values, recently_modified=False)
    self.df_tracks["group_color"] = self.df_tracks[status_column].apply(
        self.assign_color_state
    )

    self.extract_scatter_from_trajectories()
    self.give_cell_information()

    # Clear the button's existing connections (no-argument disconnect) before
    # attaching the handler; presumably so repeated calls don't stack slots.
    self.correct_btn.disconnect()
    self.correct_btn.clicked.connect(self.show_annotation_buttons)
1057
+
1058
+ def del_cell(self):
1059
+ logger.info(
1060
+ f"User interactions: Deleting cell #{self.track_of_interest} (setting status to 99)"
1061
+ )
1062
+ self.time_of_interest_le.setEnabled(False)
1063
+ self.time_of_interest_le.setText("99")
1064
+ self.apply_modification()