celldetective 1.1.1.post4__py3-none-any.whl → 1.2.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (38):
  1. celldetective/__init__.py +2 -1
  2. celldetective/extra_properties.py +62 -34
  3. celldetective/gui/__init__.py +1 -0
  4. celldetective/gui/analyze_block.py +2 -1
  5. celldetective/gui/classifier_widget.py +8 -7
  6. celldetective/gui/control_panel.py +50 -6
  7. celldetective/gui/layouts.py +5 -4
  8. celldetective/gui/neighborhood_options.py +10 -8
  9. celldetective/gui/plot_signals_ui.py +39 -11
  10. celldetective/gui/process_block.py +413 -95
  11. celldetective/gui/retrain_segmentation_model_options.py +17 -4
  12. celldetective/gui/retrain_signal_model_options.py +106 -6
  13. celldetective/gui/signal_annotator.py +25 -5
  14. celldetective/gui/signal_annotator2.py +2708 -0
  15. celldetective/gui/signal_annotator_options.py +3 -1
  16. celldetective/gui/survival_ui.py +15 -6
  17. celldetective/gui/tableUI.py +235 -39
  18. celldetective/io.py +537 -421
  19. celldetective/measure.py +919 -969
  20. celldetective/models/pair_signal_detection/blank +0 -0
  21. celldetective/neighborhood.py +426 -354
  22. celldetective/relative_measurements.py +648 -0
  23. celldetective/scripts/analyze_signals.py +1 -1
  24. celldetective/scripts/measure_cells.py +28 -8
  25. celldetective/scripts/measure_relative.py +103 -0
  26. celldetective/scripts/segment_cells.py +5 -5
  27. celldetective/scripts/track_cells.py +4 -1
  28. celldetective/scripts/train_segmentation_model.py +23 -18
  29. celldetective/scripts/train_signal_model.py +33 -0
  30. celldetective/signals.py +402 -8
  31. celldetective/tracking.py +8 -2
  32. celldetective/utils.py +93 -0
  33. {celldetective-1.1.1.post4.dist-info → celldetective-1.2.0.dist-info}/METADATA +8 -8
  34. {celldetective-1.1.1.post4.dist-info → celldetective-1.2.0.dist-info}/RECORD +38 -34
  35. {celldetective-1.1.1.post4.dist-info → celldetective-1.2.0.dist-info}/WHEEL +1 -1
  36. {celldetective-1.1.1.post4.dist-info → celldetective-1.2.0.dist-info}/LICENSE +0 -0
  37. {celldetective-1.1.1.post4.dist-info → celldetective-1.2.0.dist-info}/entry_points.txt +0 -0
  38. {celldetective-1.1.1.post4.dist-info → celldetective-1.2.0.dist-info}/top_level.txt +0 -0
@@ -423,7 +423,7 @@ class ConfigSegmentationModelTraining(QMainWindow, Styles):
423
423
  )
424
424
  if self.dataset_folder is not None:
425
425
 
426
- subfiles = glob(self.dataset_folder+"/*.tif")
426
+ subfiles = glob(self.dataset_folder+os.sep+"*.tif")
427
427
  if len(subfiles)>0:
428
428
  print(f'found {len(subfiles)} files in folder')
429
429
  self.data_folder_label.setText(self.dataset_folder[:16]+'...')
@@ -459,14 +459,26 @@ class ConfigSegmentationModelTraining(QMainWindow, Styles):
459
459
  self.data_folder_label.setToolTip('')
460
460
  self.cancel_dataset.setVisible(False)
461
461
 
462
+ def load_stardist_train_config(self):
463
+
464
+ config = os.sep.join([self.pretrained_model,"config.json"])
465
+ if os.path.exists(config):
466
+ with open(config, 'r') as f:
467
+ config = json.load(f)
468
+ if 'train_batch_size' in config:
469
+ bs = config['train_batch_size']
470
+ self.bs_le.setText(str(bs).replace('.',','))
471
+ if 'train_learning_rate' in config:
472
+ lr = config['train_learning_rate']
473
+ self.lr_le.setText(str(lr).replace('.',','))
462
474
 
463
475
  def load_pretrained_config(self):
464
476
 
465
477
  f = open(os.sep.join([self.pretrained_model,"config_input.json"]))
466
478
  data = json.load(f)
467
479
  channels = data["channels"]
468
- self.seg_folder = self.pretrained_model.split('/')[-2]
469
- self.model_name = self.pretrained_model.split('/')[-1]
480
+ self.seg_folder = self.pretrained_model.split(os.sep)[-2]
481
+ self.model_name = self.pretrained_model.split(os.sep)[-1]
470
482
  if self.model_name.startswith('CP') and self.seg_folder=='segmentation_generic':
471
483
  channels = ['brightfield_channel', 'live_nuclei_channel']
472
484
  if self.model_name=="CP_nuclei":
@@ -484,6 +496,7 @@ class ConfigSegmentationModelTraining(QMainWindow, Styles):
484
496
  if model_type=='stardist':
485
497
  self.stardist_model.setChecked(True)
486
498
  self.cellpose_model.setChecked(False)
499
+ self.load_stardist_train_config()
487
500
  else:
488
501
  self.stardist_model.setChecked(False)
489
502
  self.cellpose_model.setChecked(True)
@@ -593,7 +606,7 @@ class ConfigSegmentationModelTraining(QMainWindow, Styles):
593
606
 
594
607
  print(training_instructions)
595
608
 
596
- model_folder = '/'.join([self.software_models_dir,model_name, ''])
609
+ model_folder = os.sep.join([self.software_models_dir,model_name, ''])
597
610
  print(model_folder)
598
611
  if not os.path.exists(model_folder):
599
612
  os.mkdir(model_folder)
@@ -3,11 +3,11 @@ from PyQt5.QtCore import Qt, QSize
3
3
  from PyQt5.QtGui import QDoubleValidator, QIntValidator, QIcon
4
4
  from celldetective.gui.gui_utils import center_window, FeatureChoice, ListWidget, QHSeperationLine, FigureCanvas, GeometryChoice, OperationChoice
5
5
  from celldetective.gui.layouts import ChannelNormGenerator
6
- from superqt import QLabeledDoubleRangeSlider, QLabeledDoubleSlider,QLabeledSlider
6
+ from superqt import QLabeledDoubleRangeSlider, QLabeledDoubleSlider, QLabeledSlider, QSearchableComboBox
7
7
  from superqt.fonticon import icon
8
8
  from fonticon_mdi6 import MDI6
9
9
  from celldetective.utils import extract_experiment_channels, get_software_location
10
- from celldetective.io import interpret_tracking_configuration, load_frames, locate_signal_dataset, get_signal_datasets_list
10
+ from celldetective.io import interpret_tracking_configuration, load_frames, locate_signal_dataset, get_signal_datasets_list, load_experiment_tables
11
11
  from celldetective.measure import compute_haralick_features, contour_of_instance_segmentation
12
12
  from celldetective.signals import train_signal_model
13
13
  import numpy as np
@@ -24,6 +24,7 @@ from datetime import datetime
24
24
  import pandas as pd
25
25
  from functools import partial
26
26
  from celldetective.gui import Styles
27
+ from pandas.api.types import is_numeric_dtype
27
28
 
28
29
  class ConfigSignalModelTraining(QMainWindow, Styles):
29
30
 
@@ -32,7 +33,7 @@ class ConfigSignalModelTraining(QMainWindow, Styles):
32
33
 
33
34
  """
34
35
 
35
- def __init__(self, parent_window=None):
36
+ def __init__(self, parent_window=None, signal_mode='single-cells'):
36
37
 
37
38
  super().__init__()
38
39
  self.parent_window = parent_window
@@ -43,7 +44,16 @@ class ConfigSignalModelTraining(QMainWindow, Styles):
43
44
  self.soft_path = get_software_location()
44
45
  self.pretrained_model = None
45
46
  self.dataset_folder = None
46
- self.signal_models_dir = self.soft_path+os.sep+os.sep.join(['celldetective','models','signal_detection'])
47
+ self.current_neighborhood = None
48
+ self.reference_population = None
49
+ self.neighbor_population = None
50
+ self.signal_mode = signal_mode
51
+
52
+ if self.signal_mode=='single-cells':
53
+ self.signal_models_dir = self.soft_path+os.sep+os.sep.join(['celldetective','models','signal_detection'])
54
+ elif self.signal_mode=='pairs':
55
+ self.signal_models_dir = self.soft_path+os.sep+os.sep.join(['celldetective','models','pair_signal_detection'])
56
+ self.mode = 'pairs'
47
57
 
48
58
  self.onlyFloat = QDoubleValidator()
49
59
  self.onlyInt = QIntValidator()
@@ -272,6 +282,14 @@ class ConfigSignalModelTraining(QMainWindow, Styles):
272
282
  modelname_layout.addWidget(self.modelname_le, 70)
273
283
  layout.addLayout(modelname_layout)
274
284
 
285
+ if self.signal_mode=='pairs':
286
+ neighborhood_layout = QHBoxLayout()
287
+ neighborhood_layout.addWidget(QLabel('neighborhood of interest: '), 30)
288
+ self.neighborhood_choice_cb = QSearchableComboBox()
289
+ self.fill_available_neighborhoods()
290
+ neighborhood_layout.addWidget(self.neighborhood_choice_cb, 70)
291
+ layout.addLayout(neighborhood_layout)
292
+
275
293
  classname_layout = QHBoxLayout()
276
294
  classname_layout.addWidget(QLabel('event name: '), 30)
277
295
  self.class_name_le = QLineEdit()
@@ -311,6 +329,10 @@ class ConfigSignalModelTraining(QMainWindow, Styles):
311
329
  self.ch_norm = ChannelNormGenerator(self, mode='signals')
312
330
  layout.addLayout(self.ch_norm)
313
331
 
332
+ if self.signal_mode=='pairs':
333
+ self.neighborhood_choice_cb.currentIndexChanged.connect(self.neighborhood_changed)
334
+ self.neighborhood_changed()
335
+
314
336
  model_length_layout = QHBoxLayout()
315
337
  model_length_layout.addWidget(QLabel('Max signal length: '), 30)
316
338
  self.model_length_slider = QLabeledSlider()
@@ -323,6 +345,84 @@ class ConfigSignalModelTraining(QMainWindow, Styles):
323
345
  model_length_layout.addWidget(self.model_length_slider, 70)
324
346
  layout.addLayout(model_length_layout)
325
347
 
348
+ def neighborhood_changed(self):
349
+
350
+ neigh = self.neighborhood_choice_cb.currentText()
351
+ self.current_neighborhood = neigh.replace('target_ref_','').replace('effector_ref_','')
352
+ self.reference_population = ['targets' if 'target' in neigh else 'effectors'][0]
353
+ if 'target' in neigh:
354
+ if 'self' in neigh:
355
+ self.neighbor_population = 'targets'
356
+ else:
357
+ self.neighbor_population = 'effectors'
358
+ else:
359
+ if 'self' in neigh:
360
+ self.neighbor_population = 'effectors'
361
+ else:
362
+ self.neighbor_population = 'targets'
363
+
364
+ print(f'Current neighborhood: {self.current_neighborhood}')
365
+ print(f'New reference population: {self.reference_population}')
366
+ print(f'New neighbor population: {self.neighbor_population}')
367
+
368
+ # reload reference signals / neighbor signals / pair signals
369
+ # fill the channel cbs
370
+ self.df_reference = self.dataframes[self.reference_population]
371
+ self.df_neighbor = self.dataframes[self.neighbor_population]
372
+ self.df_pairs = load_experiment_tables(self.parent_window.exp_dir, population='pairs', load_pickle=False)
373
+
374
+ self.df_reference = self.df_reference.rename(columns=lambda x: 'reference_' + x)
375
+ num_cols_reference = [c for c in list(self.df_reference.columns) if is_numeric_dtype(self.df_reference[c])]
376
+ self.df_neighbor = self.df_neighbor.rename(columns=lambda x: 'neighbor_' + x)
377
+ num_cols_neighbor = [c for c in list(self.df_neighbor.columns) if is_numeric_dtype(self.df_neighbor[c])]
378
+ self.df_pairs = self.df_pairs.rename(columns=lambda x: 'pair_' + x)
379
+ num_cols_pairs = [c for c in list(self.df_pairs.columns) if is_numeric_dtype(self.df_pairs[c])]
380
+
381
+ self.signals = ['--'] + num_cols_pairs + num_cols_reference + num_cols_neighbor
382
+
383
+ for cb in self.ch_norm.channel_cbs:
384
+ # try:
385
+ # cb.disconnect()
386
+ # except:
387
+ # pass
388
+ cb.clear()
389
+ cb.addItems(self.signals)
390
+
391
+ def fill_available_neighborhoods(self):
392
+
393
+ df_targets = load_experiment_tables(self.parent_window.exp_dir, population='targets', load_pickle=True)
394
+ df_effectors = load_experiment_tables(self.parent_window.exp_dir, population='effectors', load_pickle=True)
395
+
396
+ self.dataframes = {
397
+ 'targets': df_targets,
398
+ 'effectors': df_effectors,
399
+ }
400
+
401
+ self.neighborhood_cols = []
402
+ self.reference_populations = []
403
+ self.neighbor_populations = []
404
+ if df_targets is not None:
405
+ self.neighborhood_cols.extend(['target_ref_'+c for c in list(df_targets.columns) if c.startswith('neighborhood')])
406
+ self.reference_populations.extend(['targets' for c in list(df_targets.columns) if c.startswith('neighborhood')])
407
+ for c in list(df_targets.columns):
408
+ if c.startswith('neighborhood') and '_2_' in c:
409
+ self.neighbor_populations.append('effectors')
410
+ elif c.startswith('neighborhood') and 'self' in c:
411
+ self.neighbor_populations.append('targets')
412
+
413
+ if df_effectors is not None:
414
+ self.neighborhood_cols.extend(['effector_ref_'+c for c in list(df_effectors.columns) if c.startswith('neighborhood')])
415
+ self.reference_populations.extend(['effectors' for c in list(df_effectors.columns) if c.startswith('neighborhood')])
416
+ for c in list(df_effectors.columns):
417
+ if c.startswith('neighborhood') and '_2_' in c:
418
+ self.neighbor_populations.append('targets')
419
+ elif c.startswith('neighborhood') and 'self' in c:
420
+ self.neighbor_populations.append('effectors')
421
+
422
+ print(f"The following neighborhoods were detected: {self.neighborhood_cols=} {self.reference_populations=} {self.neighbor_populations=}")
423
+
424
+ self.neighborhood_choice_cb.addItems(self.neighborhood_cols)
425
+
326
426
  def showDialog_pretrained(self):
327
427
 
328
428
  self.pretrained_model = QFileDialog.getExistingDirectory(
@@ -481,9 +581,9 @@ class ConfigSignalModelTraining(QMainWindow, Styles):
481
581
  training_instructions = {'model_name': model_name,'pretrained': pretrained_model, 'channel_option': channels, 'normalization_percentile': normalization_mode,
482
582
  'normalization_clip': clip_values,'normalization_values': norm_values, 'model_signal_length': signal_length,
483
583
  'recompile_pretrained': recompile_op, 'ds': data_folders, 'augmentation_factor': aug_factor, 'validation_split': val_split,
484
- 'learning_rate': lr, 'batch_size': bs, 'epochs': epochs, 'label': self.class_name_le.text()}
584
+ 'learning_rate': lr, 'batch_size': bs, 'epochs': epochs, 'label': self.class_name_le.text(), 'neighborhood_of_interest': self.current_neighborhood, 'reference_population': self.reference_population, 'neighbor_population': self.neighbor_population}
485
585
 
486
- model_folder = self.signal_models_dir + model_name + os.sep
586
+ model_folder = self.signal_models_dir +os.sep+ model_name + os.sep
487
587
  if not os.path.exists(model_folder):
488
588
  os.mkdir(model_folder)
489
589
 
@@ -57,6 +57,7 @@ class SignalAnnotator(QMainWindow, Styles):
57
57
 
58
58
  self.screen_height = self.parent_window.parent_window.parent_window.screen_height
59
59
  self.screen_width = self.parent_window.parent_window.parent_window.screen_width
60
+ self.value_magnitude = 1
60
61
 
61
62
  # default params
62
63
  self.class_name = 'class'
@@ -826,6 +827,8 @@ class SignalAnnotator(QMainWindow, Styles):
826
827
 
827
828
  def plot_signals(self):
828
829
 
830
+ range_values = []
831
+
829
832
  try:
830
833
  yvalues = []
831
834
  for i in range(len(self.signal_choice_cb)):
@@ -841,6 +844,8 @@ class SignalAnnotator(QMainWindow, Styles):
841
844
  xdata = self.df_tracks.loc[self.df_tracks['TRACK_ID'] == self.track_of_interest, 'FRAME'].to_numpy()
842
845
  ydata = self.df_tracks.loc[
843
846
  self.df_tracks['TRACK_ID'] == self.track_of_interest, signal_choice].to_numpy()
847
+
848
+ range_values.extend(ydata)
844
849
 
845
850
  xdata = xdata[ydata == ydata] # remove nan
846
851
  ydata = ydata[ydata == ydata]
@@ -863,6 +868,20 @@ class SignalAnnotator(QMainWindow, Styles):
863
868
  self.cell_fcanvas.canvas.draw()
864
869
  except Exception as e:
865
870
  print(f"Plot signals: {e=}")
871
+
872
+ if len(range_values)>0:
873
+ range_values = np.array(range_values)
874
+ if len(range_values[range_values==range_values])>0:
875
+ if len(range_values[range_values>0])>0:
876
+ self.value_magnitude = np.nanpercentile(range_values, 1)
877
+ else:
878
+ self.value_magnitude = 1
879
+ self.non_log_ymin = 0.98*np.nanmin(range_values)
880
+ self.non_log_ymax = np.nanmax(range_values)*1.02
881
+ if self.cell_ax.get_yscale()=='linear':
882
+ self.cell_ax.set_ylim(self.non_log_ymin, self.non_log_ymax)
883
+ else:
884
+ self.cell_ax.set_ylim(self.value_magnitude, self.non_log_ymax)
866
885
 
867
886
  def extract_scatter_from_trajectories(self):
868
887
 
@@ -1294,19 +1313,20 @@ class SignalAnnotator(QMainWindow, Styles):
1294
1313
  """
1295
1314
 
1296
1315
  try:
1297
- if self.cell_ax.get_yscale() == 'linear':
1316
+ if self.cell_ax.get_yscale()=='linear':
1317
+ ymin,ymax = self.cell_ax.get_ylim()
1298
1318
  self.cell_ax.set_yscale('log')
1299
- self.log_btn.setIcon(icon(MDI6.math_log, color="#1565c0"))
1319
+ self.log_btn.setIcon(icon(MDI6.math_log,color="#1565c0"))
1320
+ self.cell_ax.set_ylim(self.value_magnitude, ymax)
1300
1321
  else:
1301
1322
  self.cell_ax.set_yscale('linear')
1302
- self.log_btn.setIcon(icon(MDI6.math_log, color="black"))
1323
+ self.log_btn.setIcon(icon(MDI6.math_log,color="black"))
1303
1324
  except Exception as e:
1304
1325
  print(e)
1305
1326
 
1306
- # self.cell_ax.autoscale()
1327
+ #self.cell_ax.autoscale()
1307
1328
  self.cell_fcanvas.canvas.draw_idle()
1308
1329
 
1309
-
1310
1330
  class MeasureAnnotator(SignalAnnotator):
1311
1331
 
1312
1332
  def __init__(self, parent_window=None):