celldetective 1.1.1.post4__py3-none-any.whl → 1.2.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (38) hide show
  1. celldetective/__init__.py +2 -1
  2. celldetective/extra_properties.py +62 -34
  3. celldetective/gui/__init__.py +1 -0
  4. celldetective/gui/analyze_block.py +2 -1
  5. celldetective/gui/classifier_widget.py +8 -7
  6. celldetective/gui/control_panel.py +50 -6
  7. celldetective/gui/layouts.py +5 -4
  8. celldetective/gui/neighborhood_options.py +10 -8
  9. celldetective/gui/plot_signals_ui.py +39 -11
  10. celldetective/gui/process_block.py +413 -95
  11. celldetective/gui/retrain_segmentation_model_options.py +17 -4
  12. celldetective/gui/retrain_signal_model_options.py +106 -6
  13. celldetective/gui/signal_annotator.py +25 -5
  14. celldetective/gui/signal_annotator2.py +2708 -0
  15. celldetective/gui/signal_annotator_options.py +3 -1
  16. celldetective/gui/survival_ui.py +15 -6
  17. celldetective/gui/tableUI.py +235 -39
  18. celldetective/io.py +537 -421
  19. celldetective/measure.py +919 -969
  20. celldetective/models/pair_signal_detection/blank +0 -0
  21. celldetective/neighborhood.py +426 -354
  22. celldetective/relative_measurements.py +648 -0
  23. celldetective/scripts/analyze_signals.py +1 -1
  24. celldetective/scripts/measure_cells.py +28 -8
  25. celldetective/scripts/measure_relative.py +103 -0
  26. celldetective/scripts/segment_cells.py +5 -5
  27. celldetective/scripts/track_cells.py +4 -1
  28. celldetective/scripts/train_segmentation_model.py +23 -18
  29. celldetective/scripts/train_signal_model.py +33 -0
  30. celldetective/signals.py +402 -8
  31. celldetective/tracking.py +8 -2
  32. celldetective/utils.py +93 -0
  33. {celldetective-1.1.1.post4.dist-info → celldetective-1.2.0.dist-info}/METADATA +8 -8
  34. {celldetective-1.1.1.post4.dist-info → celldetective-1.2.0.dist-info}/RECORD +38 -34
  35. {celldetective-1.1.1.post4.dist-info → celldetective-1.2.0.dist-info}/WHEEL +1 -1
  36. {celldetective-1.1.1.post4.dist-info → celldetective-1.2.0.dist-info}/LICENSE +0 -0
  37. {celldetective-1.1.1.post4.dist-info → celldetective-1.2.0.dist-info}/entry_points.txt +0 -0
  38. {celldetective-1.1.1.post4.dist-info → celldetective-1.2.0.dist-info}/top_level.txt +0 -0
@@ -34,7 +34,9 @@ class ConfigSignalAnnotator(QMainWindow, Styles):
34
34
  self.instructions_path = self.parent_window.exp_dir + "configs/signal_annotator_config_targets.json"
35
35
  elif self.mode=="effectors":
36
36
  self.instructions_path = self.parent_window.exp_dir + "configs/signal_annotator_config_effectors.json"
37
-
37
+ elif self.mode == "neighborhood":
38
+ self.instructions_path = self.parent_window.exp_dir + "configs/signal_annotator_config_neighborhood.json"
39
+
38
40
  exp_config = self.exp_dir +"config.ini"
39
41
  #self.config_path = self.exp_dir + self.config_name
40
42
  self.channel_names, self.channels = extract_experiment_channels(exp_config)
@@ -109,9 +109,10 @@ class ConfigSurvival(QWidget, Styles):
109
109
  main_layout.addWidget(panel_title, alignment=Qt.AlignCenter)
110
110
 
111
111
 
112
- labels = [QLabel('population: '), QLabel('time of\nreference: '), QLabel('time of\ninterest: '), QLabel('exclude\nclass: '), QLabel('cmap: ')] #QLabel('class: '),
113
- self.cb_options = [['targets','effectors'], ['0','t_firstdetection'], ['t0'], ['--'], list(plt.colormaps())] #['class'],
112
+ labels = [QLabel('population: '), QLabel('time of\nreference: '), QLabel('time of\ninterest: '), QLabel('cmap: ')] #QLabel('class: '),
113
+ self.cb_options = [['targets','effectors'], ['0'], [], list(plt.colormaps())] #['class'],
114
114
  self.cbs = [QComboBox() for i in range(len(labels))]
115
+
115
116
  self.cbs[-1] = QColormapComboBox()
116
117
  self.cbs[0].currentIndexChanged.connect(self.set_classes_and_times)
117
118
 
@@ -133,6 +134,12 @@ class ConfigSurvival(QWidget, Styles):
133
134
 
134
135
  main_layout.addLayout(choice_layout)
135
136
 
137
+ select_layout = QHBoxLayout()
138
+ select_layout.addWidget(QLabel('select cells\nwith query: '), 33)
139
+ self.query_le = QLineEdit()
140
+ select_layout.addWidget(self.query_le, 66)
141
+ main_layout.addLayout(select_layout)
142
+
136
143
  self.cbs[0].setCurrentIndex(0)
137
144
  self.cbs[1].setCurrentText('t_firstdetection')
138
145
 
@@ -218,10 +225,12 @@ class ConfigSurvival(QWidget, Styles):
218
225
 
219
226
  if self.df is not None:
220
227
 
221
- excluded_class = self.cbs[3].currentText()
222
- if excluded_class!='--':
223
- print(f"Excluding {excluded_class}...")
224
- self.df = self.df.loc[~(self.df[excluded_class].isin([0,2])),:]
228
+ try:
229
+ query_text = self.query_le.text()
230
+ if query_text != '':
231
+ self.df = self.df.query(query_text)
232
+ except Exception as e:
233
+ print(e, ' The query is misunderstood and will not be applied...')
225
234
 
226
235
  self.compute_survival_functions()
227
236
  # prepare survival
@@ -20,6 +20,10 @@ from matplotlib import colormaps
20
20
 
21
21
  class PandasModel(QAbstractTableModel):
22
22
 
23
+ """
24
+ from https://stackoverflow.com/questions/31475965/fastest-way-to-populate-qtableview-from-pandas-data-frame
25
+ """
26
+
23
27
  def __init__(self, data):
24
28
  QAbstractTableModel.__init__(self)
25
29
  self._data = data
@@ -67,7 +71,7 @@ class QueryWidget(QWidget):
67
71
  try:
68
72
  query_text = self.query_le.text() #.replace('class', '`class`')
69
73
  tab = self.parent_window.data.query(query_text)
70
- self.subtable = TableUI(tab, query_text, plot_mode="static")
74
+ self.subtable = TableUI(tab, query_text, plot_mode="static", population=self.parent_window.population)
71
75
  self.subtable.show()
72
76
  self.close()
73
77
  except Exception as e:
@@ -236,7 +240,47 @@ class DifferentiateColWidget(QWidget, Styles):
236
240
  self.parent_window.table_view.setModel(self.parent_window.model)
237
241
  self.close()
238
242
 
243
+ class AbsColWidget(QWidget, Styles):
244
+
245
+ def __init__(self, parent_window, column=None):
246
+
247
+ super().__init__()
248
+ self.parent_window = parent_window
249
+ self.column = column
250
+
251
+ self.setWindowTitle("abs(.)")
252
+ # Create the QComboBox and add some items
253
+ center_window(self)
254
+
255
+ layout = QVBoxLayout(self)
256
+ layout.setContentsMargins(30,30,30,30)
257
+
258
+ self.measurements_cb = QComboBox()
259
+ self.measurements_cb.addItems(list(self.parent_window.data.columns))
260
+ if self.column is not None:
261
+ idx = self.measurements_cb.findText(self.column)
262
+ self.measurements_cb.setCurrentIndex(idx)
263
+
264
+ measurement_layout = QHBoxLayout()
265
+ measurement_layout.addWidget(QLabel('measurements: '), 25)
266
+ measurement_layout.addWidget(self.measurements_cb, 75)
267
+ layout.addLayout(measurement_layout)
268
+
269
+ self.submit_btn = QPushButton('Compute')
270
+ self.submit_btn.setStyleSheet(self.button_style_sheet)
271
+ self.submit_btn.clicked.connect(self.compute_abs_and_add_new_column)
272
+ layout.addWidget(self.submit_btn, 30)
239
273
 
274
+ self.setAttribute(Qt.WA_DeleteOnClose)
275
+
276
+
277
+ def compute_abs_and_add_new_column(self):
278
+
279
+
280
+ self.parent_window.data['|'+self.measurements_cb.currentText()+'|'] = self.parent_window.data[self.measurements_cb.currentText()].abs()
281
+ self.parent_window.model = PandasModel(self.parent_window.data)
282
+ self.parent_window.table_view.setModel(self.parent_window.model)
283
+ self.close()
240
284
 
241
285
  class RenameColWidget(QWidget):
242
286
 
@@ -275,6 +319,7 @@ class RenameColWidget(QWidget):
275
319
 
276
320
 
277
321
  class TableUI(QMainWindow, Styles):
322
+
278
323
  def __init__(self, data, title, population='targets',plot_mode="plot_track_signals", *args, **kwargs):
279
324
 
280
325
  QMainWindow.__init__(self, *args, **kwargs)
@@ -286,6 +331,16 @@ class TableUI(QMainWindow, Styles):
286
331
  self.plot_mode = plot_mode
287
332
  self.population = population
288
333
  self.numerics = ['int16', 'int32', 'int64', 'float16', 'float32', 'float64']
334
+ self.groupby_cols = ['position', 'TRACK_ID']
335
+ self.tracks = False
336
+
337
+ if self.population=='pairs':
338
+ self.groupby_cols = ['position','reference_population', 'neighbor_population','REFERENCE_ID', 'NEIGHBOR_ID', 'FRAME']
339
+ self.tracks = True # for now
340
+ else:
341
+ if 'TRACK_ID' in data.columns:
342
+ if not np.all(data['TRACK_ID'].isnull()):
343
+ self.tracks = True
289
344
 
290
345
  self._createMenuBar()
291
346
  self._createActions()
@@ -328,6 +383,18 @@ class TableUI(QMainWindow, Styles):
328
383
  self.groupby_action.triggered.connect(self.set_projection_mode_tracks)
329
384
  self.groupby_action.setShortcut("Ctrl+g")
330
385
  self.fileMenu.addAction(self.groupby_action)
386
+ if not self.tracks:
387
+ self.groupby_action.setEnabled(False)
388
+
389
+ if self.population=='pairs':
390
+ self.groupby_neigh_action = QAction("&Group by neighbors...", self)
391
+ self.groupby_neigh_action.triggered.connect(self.set_projection_mode_neigh)
392
+ self.fileMenu.addAction(self.groupby_neigh_action)
393
+
394
+ self.groupby_ref_action = QAction("&Group by reference...", self)
395
+ self.groupby_ref_action.triggered.connect(self.set_projection_mode_ref)
396
+ self.fileMenu.addAction(self.groupby_ref_action)
397
+
331
398
 
332
399
  self.groupby_time_action = QAction("&Group by frames...", self)
333
400
  self.groupby_time_action.triggered.connect(self.groupby_time_table)
@@ -348,16 +415,76 @@ class TableUI(QMainWindow, Styles):
348
415
  #self.rename_col_action.setShortcut(Qt.Key_Delete)
349
416
  self.editMenu.addAction(self.rename_col_action)
350
417
 
418
+ if self.population=='pairs':
419
+ self.merge_action = QAction('&Merge...', self)
420
+ self.merge_action.triggered.connect(self.merge_tables)
421
+ #self.rename_col_action.setShortcut(Qt.Key_Delete)
422
+ self.editMenu.addAction(self.merge_action)
423
+
351
424
  self.derivative_action = QAction('&Differentiate...', self)
352
425
  self.derivative_action.triggered.connect(self.differenciate_selected_feature)
353
426
  self.derivative_action.setShortcut("Ctrl+D")
354
427
  self.mathMenu.addAction(self.derivative_action)
355
428
 
429
+ self.abs_action = QAction('&Absolute value...', self)
430
+ self.abs_action.triggered.connect(self.take_abs_of_selected_feature)
431
+ #self.derivative_action.setShortcut("Ctrl+D")
432
+ self.mathMenu.addAction(self.abs_action)
433
+
356
434
  self.onehot_action = QAction('&One hot to categorical...', self)
357
435
  self.onehot_action.triggered.connect(self.transform_one_hot_cols_to_categorical)
358
436
  #self.onehot_action.setShortcut("Ctrl+D")
359
437
  self.mathMenu.addAction(self.onehot_action)
360
438
 
439
+ def merge_tables(self):
440
+
441
+ expanded_table = []
442
+
443
+ for neigh, group in self.data.groupby(['reference_population','neighbor_population']):
444
+ print(f'{neigh=}')
445
+ ref_pop = neigh[0]; neigh_pop = neigh[1];
446
+ for pos,pos_group in group.groupby('position'):
447
+ print(f'{pos=}')
448
+
449
+ ref_tab = os.sep.join([pos,'output','tables',f'trajectories_{ref_pop}.csv'])
450
+ neigh_tab = os.sep.join([pos,'output','tables',f'trajectories_{neigh_pop}.csv'])
451
+ if os.path.exists(ref_tab):
452
+ df_ref = pd.read_csv(ref_tab)
453
+ if 'TRACK_ID' in df_ref.columns:
454
+ if not np.all(df_ref['TRACK_ID'].isnull()):
455
+ ref_merge_cols = ['TRACK_ID','FRAME']
456
+ else:
457
+ ref_merge_cols = ['ID','FRAME']
458
+ else:
459
+ ref_merge_cols = ['ID','FRAME']
460
+ if os.path.exists(neigh_tab):
461
+ df_neigh = pd.read_csv(neigh_tab)
462
+ if 'TRACK_ID' in df_neigh.columns:
463
+ if not np.all(df_neigh['TRACK_ID'].isnull()):
464
+ neigh_merge_cols = ['TRACK_ID','FRAME']
465
+ else:
466
+ neigh_merge_cols = ['ID','FRAME']
467
+ else:
468
+ neigh_merge_cols = ['ID','FRAME']
469
+
470
+ df_ref = df_ref.add_prefix('reference_',axis=1)
471
+ df_neigh = df_neigh.add_prefix('neighbor_',axis=1)
472
+ ref_merge_cols = ['reference_'+c for c in ref_merge_cols]
473
+ neigh_merge_cols = ['neighbor_'+c for c in neigh_merge_cols]
474
+
475
+ merge_ref = pos_group.merge(df_ref, how='outer', left_on=['REFERENCE_ID','FRAME'], right_on=ref_merge_cols, suffixes=('', '_reference'))
476
+ print(f'{merge_ref.columns=}')
477
+ merge_neigh = merge_ref.merge(df_neigh, how='outer', left_on=['NEIGHBOR_ID','FRAME'], right_on=neigh_merge_cols, suffixes=('_reference', '_neighbor'))
478
+ print(f'{merge_neigh.columns=}')
479
+ expanded_table.append(merge_neigh)
480
+
481
+ df_expanded = pd.concat(expanded_table, axis=0, ignore_index = True)
482
+ df_expanded = df_expanded.sort_values(by=['position', 'reference_population','neighbor_population','REFERENCE_ID','NEIGHBOR_ID','FRAME'])
483
+ df_expanded = df_expanded.dropna(axis=0, subset=['REFERENCE_ID','NEIGHBOR_ID','reference_population','neighbor_population'])
484
+ self.subtable = TableUI(df_expanded, 'merge', plot_mode = "static", population='pairs')
485
+ self.subtable.show()
486
+
487
+
361
488
  def delete_columns(self):
362
489
 
363
490
  x = self.table_view.selectedIndexes()
@@ -407,8 +534,6 @@ class TableUI(QMainWindow, Styles):
407
534
  pos_group.to_csv(pos+os.sep.join(['output', 'tables', f'trajectories_{self.population}.csv']), index=False)
408
535
  print("Done...")
409
536
 
410
-
411
-
412
537
  def differenciate_selected_feature(self):
413
538
 
414
539
  # check only one col selected and assert is numerical
@@ -426,6 +551,24 @@ class TableUI(QMainWindow, Styles):
426
551
  self.diffWidget = DifferentiateColWidget(self, selected_col)
427
552
  self.diffWidget.show()
428
553
 
554
+ def take_abs_of_selected_feature(self):
555
+
556
+ # check only one col selected and assert is numerical
557
+ # open widget to select window parameters, directionality
558
+ # create new col
559
+
560
+ x = self.table_view.selectedIndexes()
561
+ col_idx = np.unique(np.array([l.column() for l in x]))
562
+ if col_idx!=0:
563
+ cols = np.array(list(self.data.columns))
564
+ selected_col = str(cols[col_idx][0])
565
+ else:
566
+ selected_col = None
567
+
568
+ self.absWidget = AbsColWidget(self, selected_col)
569
+ self.absWidget.show()
570
+
571
+
429
572
  def transform_one_hot_cols_to_categorical(self):
430
573
 
431
574
  x = self.table_view.selectedIndexes()
@@ -450,7 +593,7 @@ class TableUI(QMainWindow, Styles):
450
593
 
451
594
  num_df = self.data.select_dtypes(include=self.numerics)
452
595
 
453
- timeseries = num_df.groupby("FRAME").mean().copy()
596
+ timeseries = num_df.groupby("FRAME").sum().copy()
454
597
  timeseries["timeline"] = timeseries.index
455
598
  self.subtable = TableUI(timeseries,"Group by frames", plot_mode="plot_timeseries")
456
599
  self.subtable.show()
@@ -472,6 +615,16 @@ class TableUI(QMainWindow, Styles):
472
615
  # self.subtable = TableUI(timeseries,"Group by frames", plot_mode="plot_timeseries")
473
616
  # self.subtable.show()
474
617
 
618
+ def set_projection_mode_neigh(self):
619
+
620
+ self.groupby_cols = ['position', 'reference_population', 'neighbor_population', 'NEIGHBOR_ID', 'FRAME']
621
+ self.set_projection_mode_tracks()
622
+
623
+ def set_projection_mode_ref(self):
624
+
625
+ self.groupby_cols = ['position', 'reference_population', 'neighbor_population', 'REFERENCE_ID', 'FRAME']
626
+ self.set_projection_mode_tracks()
627
+
475
628
  def set_projection_mode_tracks(self):
476
629
 
477
630
  self.projectionWidget = QWidget()
@@ -483,6 +636,7 @@ class TableUI(QMainWindow, Styles):
483
636
 
484
637
  self.projection_option = QRadioButton('global operation: ')
485
638
  self.projection_option.setToolTip('Collapse the cell track measurements with an operation over each track.')
639
+ self.projection_option.setChecked(True)
486
640
  self.projection_option.toggled.connect(self.enable_projection_options)
487
641
  self.projection_op_cb = QComboBox()
488
642
  self.projection_op_cb.addItems(['mean','median','min','max', 'prod', 'sum'])
@@ -696,19 +850,42 @@ class TableUI(QMainWindow, Styles):
696
850
  self.x = self.x_cb.currentText()
697
851
 
698
852
  legend=True
853
+
699
854
  if self.hist_check.isChecked():
700
- sns.histplot(data=self.data, x=self.x, hue=hue_variable, legend=legend, ax=self.ax, palette=colors, kde=True, common_norm=False, stat='density')
701
- legend = False
855
+ if self.x is not None:
856
+ sns.histplot(data=self.data, x=self.x, hue=hue_variable, legend=legend, ax=self.ax, palette=colors, kde=True, common_norm=False, stat='density')
857
+ legend = False
858
+ elif self.x is None and self.y is not None:
859
+ sns.histplot(data=self.data, x=self.y, hue=hue_variable, legend=legend, ax=self.ax, palette=colors, kde=True, common_norm=False, stat='density')
860
+ legend = False
861
+ else:
862
+ pass
863
+
702
864
  if self.kde_check.isChecked():
703
- sns.kdeplot(data=self.data, x=self.x, hue=hue_variable, legend=legend, ax=self.ax, palette=colors, cut=0)
704
- legend = False
865
+ if self.x is not None:
866
+ sns.kdeplot(data=self.data, x=self.x, hue=hue_variable, legend=legend, ax=self.ax, palette=colors, cut=0)
867
+ legend = False
868
+ elif self.x is None and self.y is not None:
869
+ sns.kdeplot(data=self.data, x=self.y, hue=hue_variable, legend=legend, ax=self.ax, palette=colors, cut=0)
870
+ legend = False
871
+ else:
872
+ pass
873
+
705
874
  if self.count_check.isChecked():
706
875
  sns.countplot(data=self.data, x=self.x, hue=hue_variable, legend=legend, ax=self.ax, palette=colors)
707
876
  legend = False
877
+
878
+
708
879
  if self.ecdf_check.isChecked():
709
- sns.ecdfplot(data=self.data, x=self.x, hue=hue_variable, legend=legend, ax=self.ax, palette=colors)
710
- legend = False
711
-
880
+ if self.x is not None:
881
+ sns.ecdfplot(data=self.data, x=self.x, hue=hue_variable, legend=legend, ax=self.ax, palette=colors)
882
+ legend = False
883
+ elif self.x is None and self.y is not None:
884
+ sns.ecdfplot(data=self.data, x=self.y, hue=hue_variable, legend=legend, ax=self.ax, palette=colors)
885
+ legend = False
886
+ else:
887
+ pass
888
+
712
889
  if self.scat_check.isChecked():
713
890
  if self.x_option:
714
891
  sns.scatterplot(data=self.data, x=self.x,y=self.y, hue=hue_variable,legend=legend, ax=self.ax, palette=colors)
@@ -727,7 +904,7 @@ class TableUI(QMainWindow, Styles):
727
904
 
728
905
  if self.violin_check.isChecked():
729
906
  if self.x_option:
730
- sns.stripplot(data=self.data,x=self.x, y=self.y,dodge=True, ax=self.ax, hue=hue_variable, legend=legend, palette=colors)
907
+ sns.violinplot(data=self.data,x=self.x, y=self.y,dodge=True, ax=self.ax, hue=hue_variable, legend=legend, palette=colors)
731
908
  legend = False
732
909
  else:
733
910
  sns.violinplot(data=self.data, y=self.y,dodge=True, hue=hue_variable,legend=legend, ax=self.ax, palette=colors, cut=0)
@@ -766,32 +943,39 @@ class TableUI(QMainWindow, Styles):
766
943
 
767
944
  def set_proj_mode(self):
768
945
 
769
- self.static_columns = ['well_index', 'well_name', 'pos_name', 'position', 'well', 'status', 't0', 'class','cell_type','concentration', 'antibody', 'pharmaceutical_agent','TRACK_ID','position']
946
+ self.static_columns = ['well_index', 'well_name', 'pos_name', 'position', 'well', 'status', 't0', 'class','cell_type','concentration', 'antibody', 'pharmaceutical_agent','TRACK_ID','position', 'neighbor_population', 'reference_population', 'NEIGHBOR_ID', 'REFERENCE_ID', 'FRAME']
770
947
 
771
948
  if self.projection_option.isChecked():
772
949
 
773
950
  self.projection_mode = self.projection_op_cb.currentText()
774
- op = getattr(self.data.groupby(['position', 'TRACK_ID']), self.projection_mode)
775
- group_table = op(self.data.groupby(['position', 'TRACK_ID']))
951
+ op = getattr(self.data.groupby(self.groupby_cols), self.projection_mode)
952
+ group_table = op(self.data.groupby(self.groupby_cols))
776
953
 
777
954
  for c in self.static_columns:
778
955
  try:
779
- group_table[c] = self.data.groupby(['position','TRACK_ID'])[c].apply(lambda x: x.unique()[0])
956
+ group_table[c] = self.data.groupby(self.groupby_cols)[c].apply(lambda x: x.unique()[0])
780
957
  except Exception as e:
781
958
  print(e)
782
959
  pass
783
960
 
784
- for col in ['TRACK_ID']:
785
- first_column = group_table.pop(col)
786
- group_table.insert(0, col, first_column)
787
- group_table.pop('FRAME')
961
+ if self.population=='pairs':
962
+ for col in self.groupby_cols[1:]: #['neighbor_population', 'reference_population', 'NEIGHBOR_ID', 'REFERENCE_ID']
963
+ if col in group_table:
964
+ first_column = group_table.pop(col)
965
+ group_table.insert(0, col, first_column)
966
+ else:
967
+ for col in ['TRACK_ID']:
968
+ first_column = group_table.pop(col)
969
+ group_table.insert(0, col, first_column)
970
+ group_table.pop('FRAME')
788
971
 
789
972
 
790
973
  elif self.event_time_option.isChecked():
974
+
791
975
  time_of_interest = self.event_times_cb.currentText()
792
976
  self.projection_mode = f"measurements at {time_of_interest}"
793
977
  new_table = []
794
- for tid,group in self.data.groupby(['position','TRACK_ID']):
978
+ for tid,group in self.data.groupby(self.groupby_cols):
795
979
  time = group[time_of_interest].values[0]
796
980
  if time==time:
797
981
  time = floor(time) # floor for onset
@@ -801,16 +985,21 @@ class TableUI(QMainWindow, Styles):
801
985
  values = group.loc[group['FRAME']==time,:].to_numpy()
802
986
  if len(values)>0:
803
987
  values = dict(zip(list(self.data.columns), values[0]))
804
- values.update({'TRACK_ID': tid[1]})
805
- values.update({'position': tid[0]})
988
+ for k,c in enumerate(self.groupby_cols):
989
+ values.update({c: tid[k]})
806
990
  new_table.append(values)
807
991
 
808
992
  group_table = pd.DataFrame(new_table)
809
- for col in ['TRACK_ID']:
810
- first_column = group_table.pop(col)
811
- group_table.insert(0, col, first_column)
812
-
813
- group_table = group_table.sort_values(by=['position','TRACK_ID','FRAME'],ignore_index=True)
993
+ if self.population=='pairs':
994
+ for col in self.groupby_cols[1:]:
995
+ first_column = group_table.pop(col)
996
+ group_table.insert(0, col, first_column)
997
+ else:
998
+ for col in ['TRACK_ID']:
999
+ first_column = group_table.pop(col)
1000
+ group_table.insert(0, col, first_column)
1001
+
1002
+ group_table = group_table.sort_values(by=self.groupby_cols+['FRAME'],ignore_index=True)
814
1003
  group_table = group_table.reset_index(drop=True)
815
1004
 
816
1005
 
@@ -824,12 +1013,12 @@ class TableUI(QMainWindow, Styles):
824
1013
  df_sections = []
825
1014
  for s in unique_statuses:
826
1015
  subtab = self.data.loc[self.data[status_of_interest]==s,:]
827
- op = getattr(subtab.groupby(['position', 'TRACK_ID']), self.status_operation.currentText())
828
- subtab_projected = op(subtab.groupby(['position', 'TRACK_ID']))
829
- frame_duration = subtab.groupby(['position','TRACK_ID']).size().to_numpy()
1016
+ op = getattr(subtab.groupby(self.groupby_cols), self.status_operation.currentText())
1017
+ subtab_projected = op(subtab.groupby(self.groupby_cols))
1018
+ frame_duration = subtab.groupby(self.groupby_cols).size().to_numpy()
830
1019
  for c in self.static_columns:
831
1020
  try:
832
- subtab_projected[c] = subtab.groupby(['position', 'TRACK_ID'])[c].apply(lambda x: x.unique()[0])
1021
+ subtab_projected[c] = subtab.groupby(self.groupby_cols)[c].apply(lambda x: x.unique()[0])
833
1022
  except Exception as e:
834
1023
  print(e)
835
1024
  pass
@@ -837,11 +1026,18 @@ class TableUI(QMainWindow, Styles):
837
1026
  df_sections.append(subtab_projected)
838
1027
 
839
1028
  group_table = pd.concat(df_sections,axis=0,ignore_index=True)
840
- for col in ['duration_in_state',status_of_interest,'TRACK_ID']:
841
- first_column = group_table.pop(col)
842
- group_table.insert(0, col, first_column)
1029
+
1030
+ if self.population=='pairs':
1031
+ for col in ['duration_in_state',status_of_interest, 'neighbor_population', 'reference_population', 'NEIGHBOR_ID', 'REFERENCE_ID']:
1032
+ first_column = group_table.pop(col)
1033
+ group_table.insert(0, col, first_column)
1034
+ else:
1035
+ for col in ['duration_in_state',status_of_interest,'TRACK_ID']:
1036
+ first_column = group_table.pop(col)
1037
+ group_table.insert(0, col, first_column)
1038
+
843
1039
  group_table.pop('FRAME')
844
- group_table = group_table.sort_values(by=['position','TRACK_ID',status_of_interest],ignore_index=True)
1040
+ group_table = group_table.sort_values(by=self.groupby_cols + [status_of_interest],ignore_index=True)
845
1041
  group_table = group_table.reset_index(drop=True)
846
1042
 
847
1043
 
@@ -972,7 +1168,7 @@ class TableUI(QMainWindow, Styles):
972
1168
  print(unique_cols[k])
973
1169
  for w,well_group in self.data.groupby('well_name'):
974
1170
  for pos,pos_group in well_group.groupby('pos_name'):
975
- for tid,group_track in pos_group.groupby('TRACK_ID'):
1171
+ for tid,group_track in pos_group.groupby(self.groupby_cols[1:]):
976
1172
  ax.plot(group_track["FRAME"], group_track[column_names[unique_cols[k]]],label=column_names[unique_cols[k]])
977
1173
  #ax.plot(self.data["FRAME"][row_idx_i], y, label=column_names[unique_cols[k]])
978
1174
  ax.legend()
@@ -986,7 +1182,7 @@ class TableUI(QMainWindow, Styles):
986
1182
  self.fig, self.ax = plt.subplots(1, 1, figsize=(4, 3))
987
1183
  self.scatter_wdw = FigureCanvas(self.fig, title="scatter")
988
1184
  self.ax.clear()
989
- for tid,group in self.data.groupby('TRACK_ID'):
1185
+ for tid,group in self.data.groupby(self.groupby_cols[1:]):
990
1186
  self.ax.plot(group[column_names[unique_cols[0]]], group[column_names[unique_cols[1]]], marker="o")
991
1187
  self.ax.set_xlabel(column_names[unique_cols[0]])
992
1188
  self.ax.set_ylabel(column_names[unique_cols[1]])
@@ -1009,7 +1205,7 @@ class TableUI(QMainWindow, Styles):
1009
1205
 
1010
1206
  for w,well_group in self.data.groupby('well_name'):
1011
1207
  for pos,pos_group in well_group.groupby('pos_name'):
1012
- for tid,group_track in pos_group.groupby('TRACK_ID'):
1208
+ for tid,group_track in pos_group.groupby(self.groupby_cols[1:]):
1013
1209
  self.ax.plot(group_track["FRAME"], group_track[column_names[unique_cols[0]]],c="k", alpha = 0.1)
1014
1210
  self.ax.set_xlabel(r"$t$ [frame]")
1015
1211
  self.ax.set_ylabel(column_names[unique_cols[0]])