celldetective 1.1.1.post4__py3-none-any.whl → 1.2.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. It is provided for informational purposes only and reflects the changes between package versions as they appear in their respective public registries.
Files changed (41)
  1. celldetective/__init__.py +2 -1
  2. celldetective/extra_properties.py +62 -34
  3. celldetective/gui/__init__.py +1 -0
  4. celldetective/gui/analyze_block.py +2 -1
  5. celldetective/gui/classifier_widget.py +15 -9
  6. celldetective/gui/control_panel.py +50 -6
  7. celldetective/gui/layouts.py +5 -4
  8. celldetective/gui/neighborhood_options.py +13 -9
  9. celldetective/gui/plot_signals_ui.py +39 -11
  10. celldetective/gui/process_block.py +413 -95
  11. celldetective/gui/retrain_segmentation_model_options.py +17 -4
  12. celldetective/gui/retrain_signal_model_options.py +106 -6
  13. celldetective/gui/signal_annotator.py +29 -9
  14. celldetective/gui/signal_annotator2.py +2708 -0
  15. celldetective/gui/signal_annotator_options.py +3 -1
  16. celldetective/gui/survival_ui.py +15 -6
  17. celldetective/gui/tableUI.py +222 -60
  18. celldetective/io.py +536 -420
  19. celldetective/measure.py +919 -969
  20. celldetective/models/pair_signal_detection/blank +0 -0
  21. celldetective/models/segmentation_effectors/ricm-bimodal/config_input.json +130 -0
  22. celldetective/models/segmentation_effectors/ricm-bimodal/ricm-bimodal +0 -0
  23. celldetective/models/segmentation_effectors/ricm-bimodal/training_instructions.json +37 -0
  24. celldetective/neighborhood.py +428 -354
  25. celldetective/relative_measurements.py +648 -0
  26. celldetective/scripts/analyze_signals.py +1 -1
  27. celldetective/scripts/measure_cells.py +28 -8
  28. celldetective/scripts/measure_relative.py +103 -0
  29. celldetective/scripts/segment_cells.py +5 -5
  30. celldetective/scripts/track_cells.py +4 -1
  31. celldetective/scripts/train_segmentation_model.py +23 -18
  32. celldetective/scripts/train_signal_model.py +33 -0
  33. celldetective/signals.py +405 -8
  34. celldetective/tracking.py +8 -2
  35. celldetective/utils.py +178 -17
  36. {celldetective-1.1.1.post4.dist-info → celldetective-1.2.1.dist-info}/METADATA +8 -8
  37. {celldetective-1.1.1.post4.dist-info → celldetective-1.2.1.dist-info}/RECORD +41 -34
  38. {celldetective-1.1.1.post4.dist-info → celldetective-1.2.1.dist-info}/WHEEL +1 -1
  39. {celldetective-1.1.1.post4.dist-info → celldetective-1.2.1.dist-info}/LICENSE +0 -0
  40. {celldetective-1.1.1.post4.dist-info → celldetective-1.2.1.dist-info}/entry_points.txt +0 -0
  41. {celldetective-1.1.1.post4.dist-info → celldetective-1.2.1.dist-info}/top_level.txt +0 -0
@@ -34,7 +34,9 @@ class ConfigSignalAnnotator(QMainWindow, Styles):
34
34
  self.instructions_path = self.parent_window.exp_dir + "configs/signal_annotator_config_targets.json"
35
35
  elif self.mode=="effectors":
36
36
  self.instructions_path = self.parent_window.exp_dir + "configs/signal_annotator_config_effectors.json"
37
-
37
+ elif self.mode == "neighborhood":
38
+ self.instructions_path = self.parent_window.exp_dir + "configs/signal_annotator_config_neighborhood.json"
39
+
38
40
  exp_config = self.exp_dir +"config.ini"
39
41
  #self.config_path = self.exp_dir + self.config_name
40
42
  self.channel_names, self.channels = extract_experiment_channels(exp_config)
@@ -109,9 +109,10 @@ class ConfigSurvival(QWidget, Styles):
109
109
  main_layout.addWidget(panel_title, alignment=Qt.AlignCenter)
110
110
 
111
111
 
112
- labels = [QLabel('population: '), QLabel('time of\nreference: '), QLabel('time of\ninterest: '), QLabel('exclude\nclass: '), QLabel('cmap: ')] #QLabel('class: '),
113
- self.cb_options = [['targets','effectors'], ['0','t_firstdetection'], ['t0'], ['--'], list(plt.colormaps())] #['class'],
112
+ labels = [QLabel('population: '), QLabel('time of\nreference: '), QLabel('time of\ninterest: '), QLabel('cmap: ')] #QLabel('class: '),
113
+ self.cb_options = [['targets','effectors'], ['0'], [], list(plt.colormaps())] #['class'],
114
114
  self.cbs = [QComboBox() for i in range(len(labels))]
115
+
115
116
  self.cbs[-1] = QColormapComboBox()
116
117
  self.cbs[0].currentIndexChanged.connect(self.set_classes_and_times)
117
118
 
@@ -133,6 +134,12 @@ class ConfigSurvival(QWidget, Styles):
133
134
 
134
135
  main_layout.addLayout(choice_layout)
135
136
 
137
+ select_layout = QHBoxLayout()
138
+ select_layout.addWidget(QLabel('select cells\nwith query: '), 33)
139
+ self.query_le = QLineEdit()
140
+ select_layout.addWidget(self.query_le, 66)
141
+ main_layout.addLayout(select_layout)
142
+
136
143
  self.cbs[0].setCurrentIndex(0)
137
144
  self.cbs[1].setCurrentText('t_firstdetection')
138
145
 
@@ -218,10 +225,12 @@ class ConfigSurvival(QWidget, Styles):
218
225
 
219
226
  if self.df is not None:
220
227
 
221
- excluded_class = self.cbs[3].currentText()
222
- if excluded_class!='--':
223
- print(f"Excluding {excluded_class}...")
224
- self.df = self.df.loc[~(self.df[excluded_class].isin([0,2])),:]
228
+ try:
229
+ query_text = self.query_le.text()
230
+ if query_text != '':
231
+ self.df = self.df.query(query_text)
232
+ except Exception as e:
233
+ print(e, ' The query is misunderstood and will not be applied...')
225
234
 
226
235
  self.compute_survival_functions()
227
236
  # prepare survival
@@ -5,7 +5,7 @@ import matplotlib.pyplot as plt
5
5
  from matplotlib.cm import viridis
6
6
  plt.rcParams['svg.fonttype'] = 'none'
7
7
  from celldetective.gui.gui_utils import FigureCanvas, center_window
8
- from celldetective.utils import differentiate_per_track
8
+ from celldetective.utils import differentiate_per_track, collapse_trajectories_by_status
9
9
  import numpy as np
10
10
  import seaborn as sns
11
11
  import matplotlib.cm as mcm
@@ -20,6 +20,10 @@ from matplotlib import colormaps
20
20
 
21
21
  class PandasModel(QAbstractTableModel):
22
22
 
23
+ """
24
+ from https://stackoverflow.com/questions/31475965/fastest-way-to-populate-qtableview-from-pandas-data-frame
25
+ """
26
+
23
27
  def __init__(self, data):
24
28
  QAbstractTableModel.__init__(self)
25
29
  self._data = data
@@ -67,7 +71,7 @@ class QueryWidget(QWidget):
67
71
  try:
68
72
  query_text = self.query_le.text() #.replace('class', '`class`')
69
73
  tab = self.parent_window.data.query(query_text)
70
- self.subtable = TableUI(tab, query_text, plot_mode="static")
74
+ self.subtable = TableUI(tab, query_text, plot_mode="static", population=self.parent_window.population)
71
75
  self.subtable.show()
72
76
  self.close()
73
77
  except Exception as e:
@@ -236,7 +240,47 @@ class DifferentiateColWidget(QWidget, Styles):
236
240
  self.parent_window.table_view.setModel(self.parent_window.model)
237
241
  self.close()
238
242
 
243
+ class AbsColWidget(QWidget, Styles):
239
244
 
245
+ def __init__(self, parent_window, column=None):
246
+
247
+ super().__init__()
248
+ self.parent_window = parent_window
249
+ self.column = column
250
+
251
+ self.setWindowTitle("abs(.)")
252
+ # Create the QComboBox and add some items
253
+ center_window(self)
254
+
255
+ layout = QVBoxLayout(self)
256
+ layout.setContentsMargins(30,30,30,30)
257
+
258
+ self.measurements_cb = QComboBox()
259
+ self.measurements_cb.addItems(list(self.parent_window.data.columns))
260
+ if self.column is not None:
261
+ idx = self.measurements_cb.findText(self.column)
262
+ self.measurements_cb.setCurrentIndex(idx)
263
+
264
+ measurement_layout = QHBoxLayout()
265
+ measurement_layout.addWidget(QLabel('measurements: '), 25)
266
+ measurement_layout.addWidget(self.measurements_cb, 75)
267
+ layout.addLayout(measurement_layout)
268
+
269
+ self.submit_btn = QPushButton('Compute')
270
+ self.submit_btn.setStyleSheet(self.button_style_sheet)
271
+ self.submit_btn.clicked.connect(self.compute_abs_and_add_new_column)
272
+ layout.addWidget(self.submit_btn, 30)
273
+
274
+ self.setAttribute(Qt.WA_DeleteOnClose)
275
+
276
+
277
+ def compute_abs_and_add_new_column(self):
278
+
279
+
280
+ self.parent_window.data['|'+self.measurements_cb.currentText()+'|'] = self.parent_window.data[self.measurements_cb.currentText()].abs()
281
+ self.parent_window.model = PandasModel(self.parent_window.data)
282
+ self.parent_window.table_view.setModel(self.parent_window.model)
283
+ self.close()
240
284
 
241
285
  class RenameColWidget(QWidget):
242
286
 
@@ -275,6 +319,7 @@ class RenameColWidget(QWidget):
275
319
 
276
320
 
277
321
  class TableUI(QMainWindow, Styles):
322
+
278
323
  def __init__(self, data, title, population='targets',plot_mode="plot_track_signals", *args, **kwargs):
279
324
 
280
325
  QMainWindow.__init__(self, *args, **kwargs)
@@ -286,6 +331,16 @@ class TableUI(QMainWindow, Styles):
286
331
  self.plot_mode = plot_mode
287
332
  self.population = population
288
333
  self.numerics = ['int16', 'int32', 'int64', 'float16', 'float32', 'float64']
334
+ self.groupby_cols = ['position', 'TRACK_ID']
335
+ self.tracks = False
336
+
337
+ if self.population=='pairs':
338
+ self.groupby_cols = ['position','reference_population', 'neighbor_population','REFERENCE_ID', 'NEIGHBOR_ID']
339
+ self.tracks = True # for now
340
+ else:
341
+ if 'TRACK_ID' in data.columns:
342
+ if not np.all(data['TRACK_ID'].isnull()):
343
+ self.tracks = True
289
344
 
290
345
  self._createMenuBar()
291
346
  self._createActions()
@@ -328,6 +383,18 @@ class TableUI(QMainWindow, Styles):
328
383
  self.groupby_action.triggered.connect(self.set_projection_mode_tracks)
329
384
  self.groupby_action.setShortcut("Ctrl+g")
330
385
  self.fileMenu.addAction(self.groupby_action)
386
+ if not self.tracks:
387
+ self.groupby_action.setEnabled(False)
388
+
389
+ if self.population=='pairs':
390
+ self.groupby_neigh_action = QAction("&Group by neighbors...", self)
391
+ self.groupby_neigh_action.triggered.connect(self.set_projection_mode_neigh)
392
+ self.fileMenu.addAction(self.groupby_neigh_action)
393
+
394
+ self.groupby_ref_action = QAction("&Group by reference...", self)
395
+ self.groupby_ref_action.triggered.connect(self.set_projection_mode_ref)
396
+ self.fileMenu.addAction(self.groupby_ref_action)
397
+
331
398
 
332
399
  self.groupby_time_action = QAction("&Group by frames...", self)
333
400
  self.groupby_time_action.triggered.connect(self.groupby_time_table)
@@ -348,16 +415,76 @@ class TableUI(QMainWindow, Styles):
348
415
  #self.rename_col_action.setShortcut(Qt.Key_Delete)
349
416
  self.editMenu.addAction(self.rename_col_action)
350
417
 
418
+ if self.population=='pairs':
419
+ self.merge_action = QAction('&Merge...', self)
420
+ self.merge_action.triggered.connect(self.merge_tables)
421
+ #self.rename_col_action.setShortcut(Qt.Key_Delete)
422
+ self.editMenu.addAction(self.merge_action)
423
+
351
424
  self.derivative_action = QAction('&Differentiate...', self)
352
425
  self.derivative_action.triggered.connect(self.differenciate_selected_feature)
353
426
  self.derivative_action.setShortcut("Ctrl+D")
354
427
  self.mathMenu.addAction(self.derivative_action)
355
428
 
429
+ self.abs_action = QAction('&Absolute value...', self)
430
+ self.abs_action.triggered.connect(self.take_abs_of_selected_feature)
431
+ #self.derivative_action.setShortcut("Ctrl+D")
432
+ self.mathMenu.addAction(self.abs_action)
433
+
356
434
  self.onehot_action = QAction('&One hot to categorical...', self)
357
435
  self.onehot_action.triggered.connect(self.transform_one_hot_cols_to_categorical)
358
436
  #self.onehot_action.setShortcut("Ctrl+D")
359
437
  self.mathMenu.addAction(self.onehot_action)
360
438
 
439
+ def merge_tables(self):
440
+
441
+ expanded_table = []
442
+
443
+ for neigh, group in self.data.groupby(['reference_population','neighbor_population']):
444
+ print(f'{neigh=}')
445
+ ref_pop = neigh[0]; neigh_pop = neigh[1];
446
+ for pos,pos_group in group.groupby('position'):
447
+ print(f'{pos=}')
448
+
449
+ ref_tab = os.sep.join([pos,'output','tables',f'trajectories_{ref_pop}.csv'])
450
+ neigh_tab = os.sep.join([pos,'output','tables',f'trajectories_{neigh_pop}.csv'])
451
+ if os.path.exists(ref_tab):
452
+ df_ref = pd.read_csv(ref_tab)
453
+ if 'TRACK_ID' in df_ref.columns:
454
+ if not np.all(df_ref['TRACK_ID'].isnull()):
455
+ ref_merge_cols = ['TRACK_ID','FRAME']
456
+ else:
457
+ ref_merge_cols = ['ID','FRAME']
458
+ else:
459
+ ref_merge_cols = ['ID','FRAME']
460
+ if os.path.exists(neigh_tab):
461
+ df_neigh = pd.read_csv(neigh_tab)
462
+ if 'TRACK_ID' in df_neigh.columns:
463
+ if not np.all(df_neigh['TRACK_ID'].isnull()):
464
+ neigh_merge_cols = ['TRACK_ID','FRAME']
465
+ else:
466
+ neigh_merge_cols = ['ID','FRAME']
467
+ else:
468
+ neigh_merge_cols = ['ID','FRAME']
469
+
470
+ df_ref = df_ref.add_prefix('reference_',axis=1)
471
+ df_neigh = df_neigh.add_prefix('neighbor_',axis=1)
472
+ ref_merge_cols = ['reference_'+c for c in ref_merge_cols]
473
+ neigh_merge_cols = ['neighbor_'+c for c in neigh_merge_cols]
474
+
475
+ merge_ref = pos_group.merge(df_ref, how='outer', left_on=['REFERENCE_ID','FRAME'], right_on=ref_merge_cols, suffixes=('', '_reference'))
476
+ print(f'{merge_ref.columns=}')
477
+ merge_neigh = merge_ref.merge(df_neigh, how='outer', left_on=['NEIGHBOR_ID','FRAME'], right_on=neigh_merge_cols, suffixes=('_reference', '_neighbor'))
478
+ print(f'{merge_neigh.columns=}')
479
+ expanded_table.append(merge_neigh)
480
+
481
+ df_expanded = pd.concat(expanded_table, axis=0, ignore_index = True)
482
+ df_expanded = df_expanded.sort_values(by=['position', 'reference_population','neighbor_population','REFERENCE_ID','NEIGHBOR_ID','FRAME'])
483
+ df_expanded = df_expanded.dropna(axis=0, subset=['REFERENCE_ID','NEIGHBOR_ID','reference_population','neighbor_population'])
484
+ self.subtable = TableUI(df_expanded, 'merge', plot_mode = "static", population='pairs')
485
+ self.subtable.show()
486
+
487
+
361
488
  def delete_columns(self):
362
489
 
363
490
  x = self.table_view.selectedIndexes()
@@ -407,8 +534,6 @@ class TableUI(QMainWindow, Styles):
407
534
  pos_group.to_csv(pos+os.sep.join(['output', 'tables', f'trajectories_{self.population}.csv']), index=False)
408
535
  print("Done...")
409
536
 
410
-
411
-
412
537
  def differenciate_selected_feature(self):
413
538
 
414
539
  # check only one col selected and assert is numerical
@@ -426,6 +551,24 @@ class TableUI(QMainWindow, Styles):
426
551
  self.diffWidget = DifferentiateColWidget(self, selected_col)
427
552
  self.diffWidget.show()
428
553
 
554
+ def take_abs_of_selected_feature(self):
555
+
556
+ # check only one col selected and assert is numerical
557
+ # open widget to select window parameters, directionality
558
+ # create new col
559
+
560
+ x = self.table_view.selectedIndexes()
561
+ col_idx = np.unique(np.array([l.column() for l in x]))
562
+ if col_idx!=0:
563
+ cols = np.array(list(self.data.columns))
564
+ selected_col = str(cols[col_idx][0])
565
+ else:
566
+ selected_col = None
567
+
568
+ self.absWidget = AbsColWidget(self, selected_col)
569
+ self.absWidget.show()
570
+
571
+
429
572
  def transform_one_hot_cols_to_categorical(self):
430
573
 
431
574
  x = self.table_view.selectedIndexes()
@@ -450,7 +593,7 @@ class TableUI(QMainWindow, Styles):
450
593
 
451
594
  num_df = self.data.select_dtypes(include=self.numerics)
452
595
 
453
- timeseries = num_df.groupby("FRAME").mean().copy()
596
+ timeseries = num_df.groupby("FRAME").sum().copy()
454
597
  timeseries["timeline"] = timeseries.index
455
598
  self.subtable = TableUI(timeseries,"Group by frames", plot_mode="plot_timeseries")
456
599
  self.subtable.show()
@@ -472,6 +615,16 @@ class TableUI(QMainWindow, Styles):
472
615
  # self.subtable = TableUI(timeseries,"Group by frames", plot_mode="plot_timeseries")
473
616
  # self.subtable.show()
474
617
 
618
+ def set_projection_mode_neigh(self):
619
+
620
+ self.groupby_cols = ['position', 'reference_population', 'neighbor_population', 'NEIGHBOR_ID', 'FRAME']
621
+ self.set_projection_mode_tracks()
622
+
623
+ def set_projection_mode_ref(self):
624
+
625
+ self.groupby_cols = ['position', 'reference_population', 'neighbor_population', 'REFERENCE_ID', 'FRAME']
626
+ self.set_projection_mode_tracks()
627
+
475
628
  def set_projection_mode_tracks(self):
476
629
 
477
630
  self.projectionWidget = QWidget()
@@ -483,6 +636,7 @@ class TableUI(QMainWindow, Styles):
483
636
 
484
637
  self.projection_option = QRadioButton('global operation: ')
485
638
  self.projection_option.setToolTip('Collapse the cell track measurements with an operation over each track.')
639
+ self.projection_option.setChecked(True)
486
640
  self.projection_option.toggled.connect(self.enable_projection_options)
487
641
  self.projection_op_cb = QComboBox()
488
642
  self.projection_op_cb.addItems(['mean','median','min','max', 'prod', 'sum'])
@@ -696,19 +850,42 @@ class TableUI(QMainWindow, Styles):
696
850
  self.x = self.x_cb.currentText()
697
851
 
698
852
  legend=True
853
+
699
854
  if self.hist_check.isChecked():
700
- sns.histplot(data=self.data, x=self.x, hue=hue_variable, legend=legend, ax=self.ax, palette=colors, kde=True, common_norm=False, stat='density')
701
- legend = False
855
+ if self.x is not None:
856
+ sns.histplot(data=self.data, x=self.x, hue=hue_variable, legend=legend, ax=self.ax, palette=colors, kde=True, common_norm=False, stat='density')
857
+ legend = False
858
+ elif self.x is None and self.y is not None:
859
+ sns.histplot(data=self.data, x=self.y, hue=hue_variable, legend=legend, ax=self.ax, palette=colors, kde=True, common_norm=False, stat='density')
860
+ legend = False
861
+ else:
862
+ pass
863
+
702
864
  if self.kde_check.isChecked():
703
- sns.kdeplot(data=self.data, x=self.x, hue=hue_variable, legend=legend, ax=self.ax, palette=colors, cut=0)
704
- legend = False
865
+ if self.x is not None:
866
+ sns.kdeplot(data=self.data, x=self.x, hue=hue_variable, legend=legend, ax=self.ax, palette=colors, cut=0)
867
+ legend = False
868
+ elif self.x is None and self.y is not None:
869
+ sns.kdeplot(data=self.data, x=self.y, hue=hue_variable, legend=legend, ax=self.ax, palette=colors, cut=0)
870
+ legend = False
871
+ else:
872
+ pass
873
+
705
874
  if self.count_check.isChecked():
706
875
  sns.countplot(data=self.data, x=self.x, hue=hue_variable, legend=legend, ax=self.ax, palette=colors)
707
876
  legend = False
877
+
878
+
708
879
  if self.ecdf_check.isChecked():
709
- sns.ecdfplot(data=self.data, x=self.x, hue=hue_variable, legend=legend, ax=self.ax, palette=colors)
710
- legend = False
711
-
880
+ if self.x is not None:
881
+ sns.ecdfplot(data=self.data, x=self.x, hue=hue_variable, legend=legend, ax=self.ax, palette=colors)
882
+ legend = False
883
+ elif self.x is None and self.y is not None:
884
+ sns.ecdfplot(data=self.data, x=self.y, hue=hue_variable, legend=legend, ax=self.ax, palette=colors)
885
+ legend = False
886
+ else:
887
+ pass
888
+
712
889
  if self.scat_check.isChecked():
713
890
  if self.x_option:
714
891
  sns.scatterplot(data=self.data, x=self.x,y=self.y, hue=hue_variable,legend=legend, ax=self.ax, palette=colors)
@@ -727,7 +904,7 @@ class TableUI(QMainWindow, Styles):
727
904
 
728
905
  if self.violin_check.isChecked():
729
906
  if self.x_option:
730
- sns.stripplot(data=self.data,x=self.x, y=self.y,dodge=True, ax=self.ax, hue=hue_variable, legend=legend, palette=colors)
907
+ sns.violinplot(data=self.data,x=self.x, y=self.y,dodge=True, ax=self.ax, hue=hue_variable, legend=legend, palette=colors)
731
908
  legend = False
732
909
  else:
733
910
  sns.violinplot(data=self.data, y=self.y,dodge=True, hue=hue_variable,legend=legend, ax=self.ax, palette=colors, cut=0)
@@ -766,32 +943,39 @@ class TableUI(QMainWindow, Styles):
766
943
 
767
944
  def set_proj_mode(self):
768
945
 
769
- self.static_columns = ['well_index', 'well_name', 'pos_name', 'position', 'well', 'status', 't0', 'class','cell_type','concentration', 'antibody', 'pharmaceutical_agent','TRACK_ID','position']
946
+ self.static_columns = ['well_index', 'well_name', 'pos_name', 'position', 'well', 'status', 't0', 'class','cell_type','concentration', 'antibody', 'pharmaceutical_agent','TRACK_ID','position', 'neighbor_population', 'reference_population', 'NEIGHBOR_ID', 'REFERENCE_ID', 'FRAME']
770
947
 
771
948
  if self.projection_option.isChecked():
772
949
 
773
950
  self.projection_mode = self.projection_op_cb.currentText()
774
- op = getattr(self.data.groupby(['position', 'TRACK_ID']), self.projection_mode)
775
- group_table = op(self.data.groupby(['position', 'TRACK_ID']))
951
+ op = getattr(self.data.groupby(self.groupby_cols), self.projection_mode)
952
+ group_table = op(self.data.groupby(self.groupby_cols))
776
953
 
777
954
  for c in self.static_columns:
778
955
  try:
779
- group_table[c] = self.data.groupby(['position','TRACK_ID'])[c].apply(lambda x: x.unique()[0])
956
+ group_table[c] = self.data.groupby(self.groupby_cols)[c].apply(lambda x: x.unique()[0])
780
957
  except Exception as e:
781
958
  print(e)
782
959
  pass
783
960
 
784
- for col in ['TRACK_ID']:
785
- first_column = group_table.pop(col)
786
- group_table.insert(0, col, first_column)
787
- group_table.pop('FRAME')
961
+ if self.population=='pairs':
962
+ for col in reversed(self.groupby_cols): #['neighbor_population', 'reference_population', 'NEIGHBOR_ID', 'REFERENCE_ID']
963
+ if col in group_table:
964
+ first_column = group_table.pop(col)
965
+ group_table.insert(0, col, first_column)
966
+ else:
967
+ for col in ['TRACK_ID']:
968
+ first_column = group_table.pop(col)
969
+ group_table.insert(0, col, first_column)
970
+ group_table.pop('FRAME')
788
971
 
789
972
 
790
973
  elif self.event_time_option.isChecked():
974
+
791
975
  time_of_interest = self.event_times_cb.currentText()
792
976
  self.projection_mode = f"measurements at {time_of_interest}"
793
977
  new_table = []
794
- for tid,group in self.data.groupby(['position','TRACK_ID']):
978
+ for tid,group in self.data.groupby(self.groupby_cols):
795
979
  time = group[time_of_interest].values[0]
796
980
  if time==time:
797
981
  time = floor(time) # floor for onset
@@ -801,49 +985,27 @@ class TableUI(QMainWindow, Styles):
801
985
  values = group.loc[group['FRAME']==time,:].to_numpy()
802
986
  if len(values)>0:
803
987
  values = dict(zip(list(self.data.columns), values[0]))
804
- values.update({'TRACK_ID': tid[1]})
805
- values.update({'position': tid[0]})
988
+ for k,c in enumerate(self.groupby_cols):
989
+ values.update({c: tid[k]})
806
990
  new_table.append(values)
807
991
 
808
992
  group_table = pd.DataFrame(new_table)
809
- for col in ['TRACK_ID']:
810
- first_column = group_table.pop(col)
811
- group_table.insert(0, col, first_column)
812
-
813
- group_table = group_table.sort_values(by=['position','TRACK_ID','FRAME'],ignore_index=True)
993
+ if self.population=='pairs':
994
+ for col in self.groupby_cols[1:]:
995
+ first_column = group_table.pop(col)
996
+ group_table.insert(0, col, first_column)
997
+ else:
998
+ for col in ['TRACK_ID']:
999
+ first_column = group_table.pop(col)
1000
+ group_table.insert(0, col, first_column)
1001
+
1002
+ group_table = group_table.sort_values(by=self.groupby_cols+['FRAME'],ignore_index=True)
814
1003
  group_table = group_table.reset_index(drop=True)
815
1004
 
816
1005
 
817
1006
  elif self.per_status_option.isChecked():
818
1007
 
819
- status_of_interest = self.per_status_cb.currentText()
820
- self.projection_mode = f'{self.status_operation.currentText()} per {status_of_interest}'
821
- self.data = self.data.dropna(subset=status_of_interest,ignore_index=True)
822
- unique_statuses = np.unique(self.data[status_of_interest].to_numpy())
823
-
824
- df_sections = []
825
- for s in unique_statuses:
826
- subtab = self.data.loc[self.data[status_of_interest]==s,:]
827
- op = getattr(subtab.groupby(['position', 'TRACK_ID']), self.status_operation.currentText())
828
- subtab_projected = op(subtab.groupby(['position', 'TRACK_ID']))
829
- frame_duration = subtab.groupby(['position','TRACK_ID']).size().to_numpy()
830
- for c in self.static_columns:
831
- try:
832
- subtab_projected[c] = subtab.groupby(['position', 'TRACK_ID'])[c].apply(lambda x: x.unique()[0])
833
- except Exception as e:
834
- print(e)
835
- pass
836
- subtab_projected['duration_in_state'] = frame_duration
837
- df_sections.append(subtab_projected)
838
-
839
- group_table = pd.concat(df_sections,axis=0,ignore_index=True)
840
- for col in ['duration_in_state',status_of_interest,'TRACK_ID']:
841
- first_column = group_table.pop(col)
842
- group_table.insert(0, col, first_column)
843
- group_table.pop('FRAME')
844
- group_table = group_table.sort_values(by=['position','TRACK_ID',status_of_interest],ignore_index=True)
845
- group_table = group_table.reset_index(drop=True)
846
-
1008
+ group_table = collapse_trajectories_by_status(self.data, status=self.per_status_cb.currentText(),population=self.population, projection=self.status_operation.currentText(), groupby_columns=self.groupby_cols)
847
1009
 
848
1010
  self.subtable = TableUI(group_table,f"Group by tracks: {self.projection_mode}", plot_mode="static")
849
1011
  self.subtable.show()
@@ -972,7 +1134,7 @@ class TableUI(QMainWindow, Styles):
972
1134
  print(unique_cols[k])
973
1135
  for w,well_group in self.data.groupby('well_name'):
974
1136
  for pos,pos_group in well_group.groupby('pos_name'):
975
- for tid,group_track in pos_group.groupby('TRACK_ID'):
1137
+ for tid,group_track in pos_group.groupby(self.groupby_cols[1:]):
976
1138
  ax.plot(group_track["FRAME"], group_track[column_names[unique_cols[k]]],label=column_names[unique_cols[k]])
977
1139
  #ax.plot(self.data["FRAME"][row_idx_i], y, label=column_names[unique_cols[k]])
978
1140
  ax.legend()
@@ -986,7 +1148,7 @@ class TableUI(QMainWindow, Styles):
986
1148
  self.fig, self.ax = plt.subplots(1, 1, figsize=(4, 3))
987
1149
  self.scatter_wdw = FigureCanvas(self.fig, title="scatter")
988
1150
  self.ax.clear()
989
- for tid,group in self.data.groupby('TRACK_ID'):
1151
+ for tid,group in self.data.groupby(self.groupby_cols[1:]):
990
1152
  self.ax.plot(group[column_names[unique_cols[0]]], group[column_names[unique_cols[1]]], marker="o")
991
1153
  self.ax.set_xlabel(column_names[unique_cols[0]])
992
1154
  self.ax.set_ylabel(column_names[unique_cols[1]])
@@ -1009,7 +1171,7 @@ class TableUI(QMainWindow, Styles):
1009
1171
 
1010
1172
  for w,well_group in self.data.groupby('well_name'):
1011
1173
  for pos,pos_group in well_group.groupby('pos_name'):
1012
- for tid,group_track in pos_group.groupby('TRACK_ID'):
1174
+ for tid,group_track in pos_group.groupby(self.groupby_cols[1:]):
1013
1175
  self.ax.plot(group_track["FRAME"], group_track[column_names[unique_cols[0]]],c="k", alpha = 0.1)
1014
1176
  self.ax.set_xlabel(r"$t$ [frame]")
1015
1177
  self.ax.set_ylabel(column_names[unique_cols[0]])