celldetective 1.3.7.post2__py3-none-any.whl → 1.3.8__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -166,7 +166,7 @@ class DifferentiateColWidget(QWidget, Styles):
166
166
  layout.addLayout(measurement_layout)
167
167
 
168
168
  self.window_size_slider = QLabeledSlider()
169
- self.window_size_slider.setRange(1,np.nanmax(self.parent_window.data.FRAME.to_numpy()))
169
+ self.window_size_slider.setRange(1,int(np.nanmax(self.parent_window.data.FRAME.to_numpy())))
170
170
  self.window_size_slider.setValue(3)
171
171
  window_layout = QHBoxLayout()
172
172
  window_layout.addWidget(QLabel('window size: '), 25)
@@ -215,6 +215,108 @@ class DifferentiateColWidget(QWidget, Styles):
215
215
  self.parent_window.table_view.setModel(self.parent_window.model)
216
216
  self.close()
217
217
 
218
+
219
+
220
class OperationOnColsWidget(QWidget, Styles):
    """Dialog applying an element-wise arithmetic operation to two table columns.

    The user picks two columns of ``parent_window.data``. Pressing *Compute*
    stores ``column 1 <op> column 2`` in a new column named after the
    operands (e.g. ``"a/b"``), rebuilds the parent's ``PandasModel`` and
    closes the dialog.

    Parameters
    ----------
    parent_window : object
        Table window exposing ``data`` (a pandas DataFrame), ``model`` and
        ``table_view`` attributes.
    column1, column2 : str or None, optional
        Column names preselected in the first / second combo box. Names that
        are not present in the table are ignored.
    operation : {'divide', 'multiply', 'add', 'subtract'}, optional
        Operation applied as ``column1 <op> column2``. Default ``'divide'``.
    """

    def __init__(self, parent_window, column1=None, column2=None, operation='divide'):

        super().__init__()
        self.parent_window = parent_window
        self.column1 = column1
        self.column2 = column2
        self.operation = operation

        self.setWindowTitle(self.operation)
        center_window(self)

        layout = QVBoxLayout(self)
        layout.setContentsMargins(30, 30, 30, 30)

        columns = list(self.parent_window.data.columns)
        # First operand (numerator for 'divide', minuend for 'subtract').
        self.col1_cb = self._add_column_selector(layout, 'column 1: ', columns, self.column1)
        # Second operand (denominator for 'divide', subtrahend for 'subtract').
        self.col2_cb = self._add_column_selector(layout, 'column 2: ', columns, self.column2)

        self.submit_btn = QPushButton('Compute')
        self.submit_btn.setStyleSheet(self.button_style_sheet)
        self.submit_btn.clicked.connect(self.compute)
        layout.addWidget(self.submit_btn, 30)

        self.setAttribute(Qt.WA_DeleteOnClose)

    @staticmethod
    def _add_column_selector(layout, label, columns, preselected):
        """Append a ``label | combo box`` row listing *columns* to *layout*; return the combo box."""
        combobox = QComboBox()
        combobox.addItems(columns)
        if preselected is not None:
            idx = combobox.findText(preselected)
            # findText returns -1 for an unknown name; setCurrentIndex(-1) would
            # leave the box empty and make compute() look up a column named "".
            if idx >= 0:
                combobox.setCurrentIndex(idx)
        row = QHBoxLayout()
        row.addWidget(QLabel(label), 25)
        row.addWidget(combobox, 75)
        layout.addLayout(row)
        return combobox

    @staticmethod
    def _show_warning(text):
        """Display a modal warning box with *text*."""
        msgBox = QMessageBox()
        msgBox.setIcon(QMessageBox.Warning)
        msgBox.setText(text)
        msgBox.setWindowTitle("Warning")
        msgBox.setStandardButtons(QMessageBox.Ok)
        msgBox.exec()

    @staticmethod
    def _safe_divide(numerator, denominator):
        """Element-wise ``numerator / denominator`` with every infinite result replaced by NaN."""
        with np.errstate(divide='ignore', invalid='ignore'):
            res = np.true_divide(numerator, denominator)
        # x/0 yields +inf or -inf depending on the sign of x; both are reported
        # as NaN. 0/0 and NaN operands already produce NaN.
        res[np.isinf(res)] = np.nan
        return res

    def compute(self):
        """Apply the operation to the selected columns and add the result to the table.

        Shows a warning and returns ``None`` without touching the table when one
        of the columns has an object dtype, or when numpy cannot apply the
        operation to the column dtypes.

        Raises
        ------
        ValueError
            If ``self.operation`` is not a supported operation.
        """
        if not self._check_cols_before_operation():
            self._show_warning("Operation could not be performed, one of the column types is object...")
            return None

        # operation -> (symbol used in the result column name, element-wise function)
        operations = {
            'divide': ('/', self._safe_divide),
            'multiply': ('*', np.multiply),
            'add': ('+', np.add),
            'subtract': ('-', np.subtract),
        }
        if self.operation not in operations:
            raise ValueError(f"Unknown operation {self.operation!r}; expected one of {sorted(operations)}.")
        symbol, func = operations[self.operation]

        try:
            res = func(self.col1, self.col2)
        except TypeError as e:
            # e.g. subtracting boolean columns or multiplying datetime columns
            self._show_warning(f"Operation could not be performed: {e}")
            return None

        name = f"{self.col1_txt}{symbol}{self.col2_txt}"
        self.parent_window.data[name] = res
        self.parent_window.model = PandasModel(self.parent_window.data)
        self.parent_window.table_view.setModel(self.parent_window.model)
        self.close()

    def _check_cols_before_operation(self):
        """Fetch the two selected columns and check that neither has an object dtype.

        Side effects: sets ``col1_txt`` / ``col2_txt`` (selected column names)
        and ``col1`` / ``col2`` (their values as numpy arrays), which
        ``compute`` relies on.

        Returns
        -------
        bool
            ``True`` when both columns have a non-object dtype.
        """
        self.col1_txt = self.col1_cb.currentText()
        self.col2_txt = self.col2_cb.currentText()

        self.col1 = self.parent_window.data[self.col1_txt].to_numpy()
        self.col2 = self.parent_window.data[self.col2_txt].to_numpy()

        return self.col1.dtype != 'O' and self.col2.dtype != 'O'
318
+
319
+
218
320
  class CalibrateColWidget(GenericOpColWidget):
219
321
 
220
322
  def __init__(self, *args, **kwargs):
@@ -563,12 +665,34 @@ class TableUI(QMainWindow, Styles):
563
665
  self.log_action = QAction('&Log (decimal)...', self)
564
666
  self.log_action.triggered.connect(self.take_log_of_selected_feature)
565
667
  #self.derivative_action.setShortcut("Ctrl+D")
566
- self.mathMenu.addAction(self.log_action)
668
+ self.mathMenu.addAction(self.log_action)
669
+
670
+
671
+ self.divide_action = QAction('&Divide...', self)
672
+ self.divide_action.triggered.connect(self.divide_signals)
673
+ #self.derivative_action.setShortcut("Ctrl+D")
674
+ self.mathMenu.addAction(self.divide_action)
675
+
676
+ self.multiply_action = QAction('&Multiply...', self)
677
+ self.multiply_action.triggered.connect(self.multiply_signals)
678
+ #self.derivative_action.setShortcut("Ctrl+D")
679
+ self.mathMenu.addAction(self.multiply_action)
680
+
681
+ self.add_action = QAction('&Add...', self)
682
+ self.add_action.triggered.connect(self.add_signals)
683
+ #self.derivative_action.setShortcut("Ctrl+D")
684
+ self.mathMenu.addAction(self.add_action)
685
+
686
+ self.subtract_action = QAction('&Subtract...', self)
687
+ self.subtract_action.triggered.connect(self.subtract_signals)
688
+ #self.derivative_action.setShortcut("Ctrl+D")
689
+ self.mathMenu.addAction(self.subtract_action)
567
690
 
568
- self.onehot_action = QAction('&One hot to categorical...', self)
569
- self.onehot_action.triggered.connect(self.transform_one_hot_cols_to_categorical)
570
- #self.onehot_action.setShortcut("Ctrl+D")
571
- self.mathMenu.addAction(self.onehot_action)
691
+
692
+ # self.onehot_action = QAction('&One hot to categorical...', self)
693
+ # self.onehot_action.triggered.connect(self.transform_one_hot_cols_to_categorical)
694
+ # #self.onehot_action.setShortcut("Ctrl+D")
695
+ # self.mathMenu.addAction(self.onehot_action)
572
696
 
573
697
  def collapse_pairs_in_neigh(self):
574
698
 
@@ -734,6 +858,96 @@ class TableUI(QMainWindow, Styles):
734
858
  pos_group.to_csv(pos[0]+os.sep.join(['output', 'tables', f'trajectories_{self.population}.csv']), index=False)
735
859
  print("Done...")
736
860
 
861
def _selected_column_pair(self):
    """Return the names of the two leftmost selected columns of the table view.

    Columns are ordered by their position in ``self.data`` (``np.unique``
    sorts the indices), not by the order in which the user selected them.
    Missing entries are ``None``: ``(None, None)`` when nothing is selected,
    ``(name, None)`` when a single column is selected.
    """
    selected_indexes = self.table_view.selectedIndexes()
    col_idx = np.unique([index.column() for index in selected_indexes])
    columns = list(self.data.columns)
    names = [str(columns[i]) for i in col_idx[:2]]
    names += [None] * (2 - len(names))
    return names[0], names[1]

def _create_operation_widget(self, operation):
    """Build an OperationOnColsWidget for *operation*, preloaded with the selected columns."""
    column1, column2 = self._selected_column_pair()
    return OperationOnColsWidget(self, column1=column1, column2=column2, operation=operation)

def divide_signals(self):
    """Open the dialog computing ``column 1 / column 2`` into a new table column."""
    self.divWidget = self._create_operation_widget('divide')
    self.divWidget.show()

def multiply_signals(self):
    """Open the dialog computing ``column 1 * column 2`` into a new table column."""
    self.mulWidget = self._create_operation_widget('multiply')
    self.mulWidget.show()

def add_signals(self):
    """Open the dialog computing ``column 1 + column 2`` into a new table column."""
    self.addiWidget = self._create_operation_widget('add')
    self.addiWidget.show()

def subtract_signals(self):
    """Open the dialog computing ``column 1 - column 2`` into a new table column."""
    self.subWidget = self._create_operation_widget('subtract')
    self.subWidget.show()
949
+
950
+
737
951
  def differenciate_selected_feature(self):
738
952
 
739
953
  # check only one col selected and assert is numerical
@@ -742,9 +956,12 @@ class TableUI(QMainWindow, Styles):
742
956
 
743
957
  x = self.table_view.selectedIndexes()
744
958
  col_idx = np.unique(np.array([l.column() for l in x]))
745
- if col_idx!=0:
959
+ if isinstance(col_idx, (list, np.ndarray)):
746
960
  cols = np.array(list(self.data.columns))
747
- selected_col = str(cols[col_idx][0])
961
+ if len(col_idx)>0:
962
+ selected_col = str(cols[col_idx[0]])
963
+ else:
964
+ selected_col = None
748
965
  else:
749
966
  selected_col = None
750
967
 
@@ -759,9 +976,12 @@ class TableUI(QMainWindow, Styles):
759
976
 
760
977
  x = self.table_view.selectedIndexes()
761
978
  col_idx = np.unique(np.array([l.column() for l in x]))
762
- if col_idx!=0:
979
+ if isinstance(col_idx, (list, np.ndarray)):
763
980
  cols = np.array(list(self.data.columns))
764
- selected_col = str(cols[col_idx][0])
981
+ if len(col_idx)>0:
982
+ selected_col = str(cols[col_idx[0]])
983
+ else:
984
+ selected_col = None
765
985
  else:
766
986
  selected_col = None
767
987
 
@@ -772,9 +992,12 @@ class TableUI(QMainWindow, Styles):
772
992
 
773
993
  x = self.table_view.selectedIndexes()
774
994
  col_idx = np.unique(np.array([l.column() for l in x]))
775
- if col_idx!=0:
995
+ if isinstance(col_idx, (list, np.ndarray)):
776
996
  cols = np.array(list(self.data.columns))
777
- selected_col = str(cols[col_idx][0])
997
+ if len(col_idx)>0:
998
+ selected_col = str(cols[col_idx[0]])
999
+ else:
1000
+ selected_col = None
778
1001
  else:
779
1002
  selected_col = None
780
1003
 
@@ -790,9 +1013,12 @@ class TableUI(QMainWindow, Styles):
790
1013
 
791
1014
  x = self.table_view.selectedIndexes()
792
1015
  col_idx = np.unique(np.array([l.column() for l in x]))
793
- if col_idx!=0:
1016
+ if isinstance(col_idx, (list, np.ndarray)):
794
1017
  cols = np.array(list(self.data.columns))
795
- selected_col = str(cols[col_idx][0])
1018
+ if len(col_idx)>0:
1019
+ selected_col = str(cols[col_idx[0]])
1020
+ else:
1021
+ selected_col = None
796
1022
  else:
797
1023
  selected_col = None
798
1024
 
@@ -804,11 +1030,14 @@ class TableUI(QMainWindow, Styles):
804
1030
 
805
1031
  x = self.table_view.selectedIndexes()
806
1032
  col_idx = np.unique(np.array([l.column() for l in x]))
807
- if list(col_idx):
1033
+ if isinstance(col_idx, (list, np.ndarray)):
808
1034
  cols = np.array(list(self.data.columns))
809
- selected_cols = cols[col_idx]
1035
+ if len(col_idx)>0:
1036
+ selected_col = str(cols[col_idx[0]])
1037
+ else:
1038
+ selected_col = None
810
1039
  else:
811
- selected_cols = None
1040
+ selected_col = None
812
1041
 
813
1042
  self.mergewidget = MergeOneHotWidget(self, selected_columns=selected_cols)
814
1043
  self.mergewidget.show()
@@ -874,7 +1103,7 @@ class TableUI(QMainWindow, Styles):
874
1103
  self.projection_option.setChecked(True)
875
1104
  self.projection_option.toggled.connect(self.enable_projection_options)
876
1105
  self.projection_op_cb = QComboBox()
877
- self.projection_op_cb.addItems(['mean','median','min','max', 'prod', 'sum'])
1106
+ self.projection_op_cb.addItems(['mean','median','min','max','first','last','prod','sum'])
878
1107
 
879
1108
  projection_layout = QHBoxLayout()
880
1109
  projection_layout.addWidget(self.projection_option, 33)
@@ -1044,8 +1273,11 @@ class TableUI(QMainWindow, Styles):
1044
1273
  all_cms = list(colormaps)
1045
1274
  for cm in all_cms:
1046
1275
  if hasattr(matplotlib.cm, str(cm).lower()):
1047
- self.cmap_cb.addColormap(cm.lower())
1048
-
1276
+ try:
1277
+ self.cmap_cb.addColormap(cm.lower())
1278
+ except:
1279
+ pass
1280
+
1049
1281
  hbox = QHBoxLayout()
1050
1282
  hbox.addWidget(QLabel('colormap: '), 33)
1051
1283
  hbox.addWidget(self.cmap_cb, 66)
@@ -932,7 +932,7 @@ class CellSizeViewer(StackVisualizer):
932
932
  with interactive sliders for diameter adjustment and circle display.
933
933
  """
934
934
 
935
- def __init__(self, initial_diameter=40, set_radius_in_list=False, diameter_slider_range=(0,200), parent_le=None, parent_list_widget=None, *args, **kwargs):
935
+ def __init__(self, initial_diameter=40, set_radius_in_list=False, diameter_slider_range=(0,500), parent_le=None, parent_list_widget=None, *args, **kwargs):
936
936
  # Initialize the widget and its attributes
937
937
 
938
938
  super().__init__(*args, **kwargs)
celldetective/io.py CHANGED
@@ -16,7 +16,7 @@ import concurrent.futures
16
16
  from csbdeep.utils import normalize_mi_ma
17
17
  from csbdeep.io import save_tiff_imagej_compatible
18
18
 
19
- import skimage.io as skio
19
+ import imageio.v2 as imageio
20
20
  from skimage.measure import regionprops_table, label
21
21
 
22
22
  from btrack.datasets import cell_config
@@ -400,6 +400,23 @@ def get_experiment_metadata(experiment):
400
400
  metadata = ConfigSectionMap(config, "Metadata")
401
401
  return metadata
402
402
 
403
def get_experiment_labels(experiment):
    """Return the per-well labels declared in the ``Labels`` section of the experiment config.

    Each entry of the section is read as a comma-separated list holding one
    value per well. An entry whose number of values differs from the number
    of wells is replaced by the well indices ``0 .. n_wells - 1``. When every
    value of an entry can be parsed by ``float`` (decimals, signs, exponents,
    surrounding whitespace), the entry becomes a list of floats. Otherwise the
    raw strings are kept.

    Parameters
    ----------
    experiment : str
        Experiment passed to ``get_config`` and ``get_experiment_wells``
        (presumably the experiment folder path).

    Returns
    -------
    dict
        Label name -> list of per-well values (floats or strings). Empty when
        ``ConfigSectionMap`` yields no entries for the section.
    """
    config = get_config(experiment)
    nbr_of_wells = len(get_experiment_wells(experiment))

    labels = ConfigSectionMap(config, "Labels")
    if not labels:
        # NOTE(review): also guards against ConfigSectionMap returning None
        # (e.g. for a config without a Labels section) -- confirm its behaviour.
        return {}

    for key in list(labels):
        values = labels[key].split(',')
        if len(values) != nbr_of_wells:
            values = [str(i) for i in range(nbr_of_wells)]
        try:
            # str.isnumeric() rejected "0.5", "-1" or " 10" and accepted "½",
            # which float() cannot parse; float() itself is the right test.
            values = [float(v) for v in values]
        except ValueError:
            pass  # at least one non-numeric value: keep the entry as strings
        labels[key] = values

    return labels
419
+
403
420
 
404
421
  def get_experiment_concentrations(experiment, dtype=str):
405
422
 
@@ -982,10 +999,8 @@ def load_experiment_tables(experiment, population='targets', well_option='*', po
982
999
  wells = get_experiment_wells(experiment)
983
1000
 
984
1001
  movie_prefix = ConfigSectionMap(config, "MovieSettings")["movie_prefix"]
985
- concentrations = get_experiment_concentrations(experiment, dtype=float)
986
- cell_types = get_experiment_cell_types(experiment)
987
- antibodies = get_experiment_antibodies(experiment)
988
- pharmaceutical_agents = get_experiment_pharmaceutical_agents(experiment)
1002
+
1003
+ labels = get_experiment_labels(experiment)
989
1004
  metadata = get_experiment_metadata(experiment) # None or dict of metadata
990
1005
  well_labels = _extract_labels_from_config(config, len(wells))
991
1006
 
@@ -1001,14 +1016,8 @@ def load_experiment_tables(experiment, population='targets', well_option='*', po
1001
1016
 
1002
1017
  well_name, well_number = extract_well_name_and_number(well_path)
1003
1018
  widx = well_indices[k]
1004
-
1005
1019
  well_alias = well_labels[widx]
1006
1020
 
1007
- well_concentration = concentrations[widx]
1008
- well_antibody = antibodies[widx]
1009
- well_cell_type = cell_types[widx]
1010
- well_pharmaceutical_agent = pharmaceutical_agents[widx]
1011
-
1012
1021
  positions = get_positions_in_well(well_path)
1013
1022
  if position_indices is not None:
1014
1023
  try:
@@ -1037,10 +1046,13 @@ def load_experiment_tables(experiment, population='targets', well_option='*', po
1037
1046
  df_pos['well_name'] = well_name
1038
1047
  df_pos['pos_name'] = pos_name
1039
1048
 
1040
- df_pos['concentration'] = well_concentration
1041
- df_pos['antibody'] = well_antibody
1042
- df_pos['cell_type'] = well_cell_type
1043
- df_pos['pharmaceutical_agent'] = well_pharmaceutical_agent
1049
+ for k in list(labels.keys()):
1050
+ values = labels[k]
1051
+ try:
1052
+ df_pos[k] = values[widx]
1053
+ except Exception as e:
1054
+ print(f"{e=}")
1055
+
1044
1056
  if metadata is not None:
1045
1057
  keys = list(metadata.keys())
1046
1058
  for k in keys:
@@ -1052,10 +1064,6 @@ def load_experiment_tables(experiment, population='targets', well_option='*', po
1052
1064
  pos_dict = {'pos_path': pos_path, 'pos_index': real_pos_index, 'pos_name': pos_name, 'table_path': table,
1053
1065
  'stack_path': stack_path,'well_path': well_path, 'well_index': real_well_index, 'well_name': well_name,
1054
1066
  'well_number': well_number, 'well_alias': well_alias}
1055
- # if metadata is not None:
1056
- # keys = list(metadata.keys())
1057
- # for k in keys:
1058
- # pos_dict.update({k: metadata[k]})
1059
1067
 
1060
1068
  df_pos_info.append(pos_dict)
1061
1069
 
@@ -3335,7 +3343,7 @@ def load_frames(img_nums, stack_path, scale=None, normalize_input=True, dtype=fl
3335
3343
  """
3336
3344
 
3337
3345
  try:
3338
- frames = skio.imread(stack_path, key=img_nums, plugin="tifffile")
3346
+ frames = imageio.imread(stack_path, key=img_nums)
3339
3347
  except Exception as e:
3340
3348
  print(
3341
3349
  f'Error in loading the frame {img_nums} {e}. Please check that the experiment channel information is consistent with the movie being read.')
@@ -8,10 +8,12 @@ from .utils import _estimate_scale_factor, _extract_channel_indices
8
8
  from pathlib import Path
9
9
  from tqdm import tqdm
10
10
  import numpy as np
11
- from celldetective.io import _view_on_napari, locate_labels, locate_stack, _view_on_napari, _check_label_dims
11
+ from celldetective.io import _view_on_napari, locate_labels, locate_stack, _view_on_napari, _check_label_dims, auto_correct_masks
12
12
  from celldetective.filters import * #rework this to give a name
13
13
  from celldetective.utils import interpolate_nan_multichannel,_rearrange_multichannel_frame, _fix_no_contrast, zoom_multiframes, _rescale_labels, rename_intensity_column, mask_edges, _prep_stardist_model, _prep_cellpose_model, estimate_unreliable_edge,_get_normalize_kwargs_from_config, _segment_image_with_stardist_model, _segment_image_with_cellpose_model
14
14
  from stardist import fill_label_holes
15
+ from stardist.matching import matching
16
+
15
17
  import scipy.ndimage as ndi
16
18
  from skimage.segmentation import watershed
17
19
  from skimage.feature import peak_local_max
@@ -717,5 +719,50 @@ def train_segmentation_model(config, use_gpu=True):
717
719
  cmd = f'python "{script_path}" --config "{config}" --use_gpu "{use_gpu}"'
718
720
  subprocess.call(cmd, shell=True)
719
721
 
722
+
723
def merge_instance_segmentation(labels, iou_matching_threshold=0.05, mode='OR'):
    """Merge several instance-segmentation label images into a single label image.

    ``labels[0]`` is the initial reference. Each following image is matched
    against the current reference with ``stardist.matching.matching`` (IoU
    criterion, ``report_matches=True``). Two cases follow:

    * a matched pair whose IoU is strictly greater than
      ``iou_matching_threshold`` is fused into the reference object. The
      reference label is kept, and both masks are combined according to ``mode``;
    * every other non-zero object of the new image is added to the result,
      with its label shifted by the maximum label of the current reference.

    The merged image becomes the reference for the next image. The final
    result is passed through ``auto_correct_masks``.

    Parameters
    ----------
    labels : sequence of numpy.ndarray
        Non-negative integer label images (0 = background) of identical shape.
        The input arrays are not modified.
    iou_matching_threshold : float, optional
        Minimum IoU (exclusive) for two objects to be fused. It is also passed
        as ``thresh`` to ``matching``, so that the pair assignment maximizes
        the number of pairs above this threshold. Default 0.05.
    mode : {'OR', 'AND', 'XOR'}, optional
        How the masks of a fused pair are combined: union, intersection or
        symmetric difference. Default 'OR'.

    Returns
    -------
    numpy.ndarray
        The merged label image, as returned by ``auto_correct_masks``.

    Raises
    ------
    ValueError
        If ``labels`` is empty or ``mode`` is not supported.
    """
    combine_masks = {'OR': np.logical_or, 'AND': np.logical_and, 'XOR': np.logical_xor}.get(mode)
    if combine_masks is None:
        raise ValueError(f"Unknown merge mode {mode!r}; expected 'OR', 'AND' or 'XOR'.")
    if len(labels) == 0:
        raise ValueError("At least one label image is required.")

    # Copy so that even a single input image is returned as a fresh array.
    merge = np.copy(labels[0])

    for other in labels[1:]:
        label_reference = merge
        # Work on a copy: fused objects are erased below, and the caller's
        # array must stay untouched.
        label_to_merge = np.copy(other)

        stats = matching(label_reference, label_to_merge, thresh=iou_matching_threshold,
                         criterion='iou', report_matches=True)
        # matched_pairs also reports low-IoU assignments; keep only those above the threshold.
        accepted_pairs = [
            pair for pair, score in zip(stats.matched_pairs, stats.matched_scores)
            if score > iou_matching_threshold
        ]

        merge = np.copy(label_reference)
        for ref_label, new_label in accepted_pairs:
            fused = combine_masks(label_reference == ref_label, label_to_merge == new_label)
            merge[merge == ref_label] = 0
            # Boolean-mask assignment works for label images of any dimensionality.
            merge[fused] = ref_label

        # Objects of the new image that were fused must not be added a second time.
        for _, new_label in accepted_pairs:
            label_to_merge[label_to_merge == new_label] = 0

        # Remaining objects are added with labels offset past the reference's maximum.
        new_objects = label_to_merge != 0
        merge[new_objects] = label_to_merge[new_objects] + int(np.amax(label_reference))

    return auto_correct_masks(merge)
765
+
766
+
720
767
  if __name__ == "__main__":
721
768
  print(segment(None,'test'))
celldetective/signals.py CHANGED
@@ -33,6 +33,7 @@ import time
33
33
  import math
34
34
  import pandas as pd
35
35
  from pandas.api.types import is_numeric_dtype
36
+ from scipy.stats import median_abs_deviation
36
37
 
37
38
  abs_path = os.sep.join([os.path.split(os.path.dirname(os.path.realpath(__file__)))[0],'celldetective'])
38
39
 
@@ -680,11 +681,22 @@ class SignalDetectionModel(object):
680
681
  if 'label' in model_config:
681
682
  self.label = model_config['label']
682
683
 
683
- self.n_channels = self.model_class.layers[0].input_shape[0][-1]
684
- self.model_signal_length = self.model_class.layers[0].input_shape[0][-2]
685
- self.n_classes = self.model_class.layers[-1].output_shape[-1]
684
+ try:
685
+ self.n_channels = self.model_class.layers[0].input_shape[0][-1]
686
+ self.model_signal_length = self.model_class.layers[0].input_shape[0][-2]
687
+ self.n_classes = self.model_class.layers[-1].output_shape[-1]
688
+ model_class_input_shape = self.model_class.layers[0].input_shape[0]
689
+ model_reg_input_shape = self.model_reg.layers[0].input_shape[0]
690
+ except AttributeError:
691
+ self.n_channels = self.model_class.input_shape[-1] #self.model_class.layers[0].input.shape[0][-1]
692
+ self.model_signal_length = self.model_class.input_shape[-2] #self.model_class.layers[0].input[0].shape[0][-2]
693
+ self.n_classes = self.model_class.output_shape[-1] #self.model_class.layers[-1].output[0].shape[-1]
694
+ model_class_input_shape = self.model_class.input_shape
695
+ model_reg_input_shape = self.model_reg.input_shape
696
+ except Exception as e:
697
+ print(e)
686
698
 
687
- assert self.model_class.layers[0].input_shape[0] == self.model_reg.layers[0].input_shape[0], f"mismatch between input shape of classification: {self.model_class.layers[0].input_shape[0]} and regression {self.model_reg.layers[0].input_shape[0]} models... Error."
699
+ assert model_class_input_shape==model_reg_input_shape, f"mismatch between input shape of classification: {self.model_class.layers[0].input_shape[0]} and regression {self.model_reg.layers[0].input_shape[0]} models... Error."
688
700
 
689
701
  return True
690
702
 
@@ -1015,8 +1027,15 @@ class SignalDetectionModel(object):
1015
1027
  # plt.plot(self.x[i,:,0])
1016
1028
  # plt.show()
1017
1029
 
1018
- assert self.x.shape[-1] == self.model_class.layers[0].input_shape[0][-1], f"Shape mismatch between the input shape and the model input shape..."
1019
- assert self.x.shape[-2] == self.model_class.layers[0].input_shape[0][-2], f"Shape mismatch between the input shape and the model input shape..."
1030
+ try:
1031
+ n_channels = self.model_class.layers[0].input_shape[0][-1]
1032
+ model_signal_length = self.model_class.layers[0].input_shape[0][-2]
1033
+ except AttributeError:
1034
+ n_channels = self.model_class.input_shape[-1]
1035
+ model_signal_length = self.model_class.input_shape[-2]
1036
+
1037
+ assert self.x.shape[-1] == n_channels, f"Shape mismatch between the input shape and the model input shape..."
1038
+ assert self.x.shape[-2] == model_signal_length, f"Shape mismatch between the input shape and the model input shape..."
1020
1039
 
1021
1040
  self.class_predictions_one_hot = self.model_class.predict(self.x)
1022
1041
  self.class_predictions = self.class_predictions_one_hot.argmax(axis=1)
@@ -1072,8 +1091,15 @@ class SignalDetectionModel(object):
1072
1091
  normalization_values=self.normalization_values, normalization_clip=self.normalization_clip,
1073
1092
  )
1074
1093
 
1075
- assert self.x.shape[-1] == self.model_reg.layers[0].input_shape[0][-1], f"Shape mismatch between the input shape and the model input shape..."
1076
- assert self.x.shape[-2] == self.model_reg.layers[0].input_shape[0][-2], f"Shape mismatch between the input shape and the model input shape..."
1094
+ try:
1095
+ n_channels = self.model_reg.layers[0].input_shape[0][-1]
1096
+ model_signal_length = self.model_reg.layers[0].input_shape[0][-2]
1097
+ except AttributeError:
1098
+ n_channels = self.model_reg.input_shape[-1]
1099
+ model_signal_length = self.model_reg.input_shape[-2]
1100
+
1101
+ assert self.x.shape[-1] == n_channels, f"Shape mismatch between the input shape and the model input shape..."
1102
+ assert self.x.shape[-2] == model_signal_length, f"Shape mismatch between the input shape and the model input shape..."
1077
1103
 
1078
1104
  if np.any(self.class_predictions==0):
1079
1105
  self.time_predictions = self.model_reg.predict(self.x[self.class_predictions==0])*self.model_signal_length
@@ -2749,7 +2775,7 @@ def sliding_msd_drift(x, y, timeline, window, mode='bi', n_points_migration=7,
2749
2775
 
2750
2776
  return s_diffusion, s_velocity
2751
2777
 
2752
- def columnwise_mean(matrix, min_nbr_values = 1):
2778
+ def columnwise_mean(matrix, min_nbr_values = 1, projection='mean'):
2753
2779
 
2754
2780
  """
2755
2781
  Calculate the column-wise mean and standard deviation of non-NaN elements in the input matrix.
@@ -2788,12 +2814,16 @@ def columnwise_mean(matrix, min_nbr_values = 1):
2788
2814
  values = matrix[:,k]
2789
2815
  values = values[values==values]
2790
2816
  if len(values[values==values])>min_nbr_values:
2791
- mean_line[k] = np.nanmean(values)
2792
- mean_line_std[k] = np.nanstd(values)
2817
+ if projection=='mean':
2818
+ mean_line[k] = np.nanmean(values)
2819
+ mean_line_std[k] = np.nanstd(values)
2820
+ elif projection=='median':
2821
+ mean_line[k] = np.nanmedian(values)
2822
+ mean_line_std[k] = median_abs_deviation(values, center=np.nanmedian, nan_policy='omit')
2793
2823
  return mean_line, mean_line_std
2794
2824
 
2795
2825
 
2796
- def mean_signal(df, signal_name, class_col, time_col=None, class_value=[0], return_matrix=False, forced_max_duration=None, min_nbr_values=2,conflict_mode='mean'):
2826
+ def mean_signal(df, signal_name, class_col, time_col=None, class_value=[0], return_matrix=False, forced_max_duration=None, min_nbr_values=2,conflict_mode='mean', projection='mean'):
2797
2827
 
2798
2828
  """
2799
2829
  Calculate the mean and standard deviation of a specified signal for tracks of a given class in the input DataFrame.
@@ -2884,7 +2914,7 @@ def mean_signal(df, signal_name, class_col, time_col=None, class_value=[0], retu
2884
2914
  signal_matrix[trackid,timeline_shifted.astype(int)] = signal
2885
2915
  trackid+=1
2886
2916
 
2887
- mean_signal, std_signal = columnwise_mean(signal_matrix, min_nbr_values=min_nbr_values)
2917
+ mean_signal, std_signal = columnwise_mean(signal_matrix, min_nbr_values=min_nbr_values, projection=projection)
2888
2918
  actual_timeline = np.linspace(-max_duration, max_duration, 2*max_duration+1)
2889
2919
  if return_matrix:
2890
2920
  return mean_signal, std_signal, actual_timeline, signal_matrix
celldetective/tracking.py CHANGED
@@ -441,9 +441,14 @@ def interpolate_per_track(group_df):
441
441
 
442
442
  """
443
443
 
444
- interpolated_group = group_df.interpolate(method='linear',limit_direction="both")
444
+ for c in list(group_df.columns):
445
+ group_df_new_dtype = group_df[c].infer_objects(copy=False)
446
+ if group_df_new_dtype.dtype!='O':
447
+ group_df[c] = group_df_new_dtype.interpolate(method='linear',limit_direction="both")
448
+
449
+ #interpolated_group = group_df.interpolate(method='linear',limit_direction="both")
445
450
 
446
- return interpolated_group
451
+ return group_df
447
452
 
448
453
  def interpolate_nan_properties(trajectories, track_label="TRACK_ID"):
449
454
 
celldetective/utils.py CHANGED
@@ -1794,7 +1794,7 @@ def ConfigSectionMap(path,section):
1794
1794
 
1795
1795
  """
1796
1796
 
1797
- Config = configparser.ConfigParser()
1797
+ Config = configparser.ConfigParser(interpolation=None)
1798
1798
  Config.read(path)
1799
1799
  dict1 = {}
1800
1800
  try:
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.2
2
2
  Name: celldetective
3
- Version: 1.3.7.post2
3
+ Version: 1.3.8
4
4
  Summary: description
5
5
  Home-page: http://github.com/remyeltorro/celldetective
6
6
  Author: Rémy Torro