celldetective 1.3.1__py3-none-any.whl → 1.3.3.post1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (34)
  1. celldetective/_version.py +1 -1
  2. celldetective/events.py +2 -0
  3. celldetective/gui/classifier_widget.py +51 -3
  4. celldetective/gui/control_panel.py +9 -3
  5. celldetective/gui/generic_signal_plot.py +161 -2
  6. celldetective/gui/gui_utils.py +90 -1
  7. celldetective/gui/measurement_options.py +35 -32
  8. celldetective/gui/plot_signals_ui.py +8 -3
  9. celldetective/gui/process_block.py +36 -114
  10. celldetective/gui/retrain_segmentation_model_options.py +3 -1
  11. celldetective/gui/signal_annotator.py +53 -26
  12. celldetective/gui/signal_annotator2.py +17 -30
  13. celldetective/gui/survival_ui.py +7 -3
  14. celldetective/gui/tableUI.py +300 -183
  15. celldetective/gui/thresholds_gui.py +195 -199
  16. celldetective/gui/viewers.py +267 -13
  17. celldetective/io.py +110 -10
  18. celldetective/measure.py +128 -88
  19. celldetective/models/segmentation_effectors/ricm_bf_all_last/config_input.json +79 -0
  20. celldetective/models/segmentation_effectors/ricm_bf_all_last/ricm_bf_all_last +0 -0
  21. celldetective/models/segmentation_effectors/ricm_bf_all_last/training_instructions.json +37 -0
  22. celldetective/models/segmentation_effectors/test-transfer/config_input.json +39 -0
  23. celldetective/models/segmentation_effectors/test-transfer/test-transfer +0 -0
  24. celldetective/neighborhood.py +154 -69
  25. celldetective/relative_measurements.py +128 -4
  26. celldetective/scripts/measure_cells.py +3 -3
  27. celldetective/signals.py +207 -213
  28. celldetective/utils.py +16 -0
  29. {celldetective-1.3.1.dist-info → celldetective-1.3.3.post1.dist-info}/METADATA +11 -10
  30. {celldetective-1.3.1.dist-info → celldetective-1.3.3.post1.dist-info}/RECORD +34 -29
  31. {celldetective-1.3.1.dist-info → celldetective-1.3.3.post1.dist-info}/WHEEL +1 -1
  32. {celldetective-1.3.1.dist-info → celldetective-1.3.3.post1.dist-info}/LICENSE +0 -0
  33. {celldetective-1.3.1.dist-info → celldetective-1.3.3.post1.dist-info}/entry_points.txt +0 -0
  34. {celldetective-1.3.1.dist-info → celldetective-1.3.3.post1.dist-info}/top_level.txt +0 -0
@@ -2,7 +2,7 @@ import numpy as np
2
2
  from celldetective.io import auto_load_number_of_frames, load_frames
3
3
  from celldetective.filters import *
4
4
  from celldetective.segmentation import filter_image, threshold_image
5
- from celldetective.measure import contour_of_instance_segmentation
5
+ from celldetective.measure import contour_of_instance_segmentation, extract_blobs_in_image
6
6
  from celldetective.utils import _get_img_num_per_channel, estimate_unreliable_edge
7
7
  from tifffile import imread
8
8
  import matplotlib.pyplot as plt
@@ -13,7 +13,7 @@ import os
13
13
 
14
14
  from PyQt5.QtWidgets import QWidget, QHBoxLayout, QPushButton, QLabel, QComboBox, QLineEdit, QListWidget, QShortcut
15
15
  from PyQt5.QtCore import Qt, QSize
16
- from PyQt5.QtGui import QKeySequence
16
+ from PyQt5.QtGui import QKeySequence, QDoubleValidator
17
17
  from celldetective.gui.gui_utils import FigureCanvas, center_window, QuickSliderLayout, QHSeperationLine, ThresholdLineEdit
18
18
  from celldetective.gui import Styles
19
19
  from superqt import QLabeledDoubleSlider, QLabeledSlider, QLabeledDoubleRangeSlider
@@ -22,7 +22,11 @@ from fonticon_mdi6 import MDI6
22
22
  from matplotlib_scalebar.scalebar import ScaleBar
23
23
  import gc
24
24
  from celldetective.utils import mask_edges
25
- from scipy.ndimage import shift
25
+ from scipy.ndimage import shift, grey_dilation
26
+ from skimage.feature import blob_dog
27
+ import math
28
+ from skimage.morphology import disk
29
+ from matplotlib.patches import Circle
26
30
 
27
31
  class StackVisualizer(QWidget, Styles):
28
32
 
@@ -80,6 +84,7 @@ class StackVisualizer(QWidget, Styles):
80
84
  self.imshow_kwargs = imshow_kwargs
81
85
  self.PxToUm = PxToUm
82
86
  self.init_contrast = False
87
+ self.channel_trigger = False
83
88
 
84
89
  self.load_stack() # need to get stack, frame etc
85
90
  self.generate_figure_canvas()
@@ -123,13 +128,8 @@ class StackVisualizer(QWidget, Styles):
123
128
 
124
129
  def locate_image_virtual(self):
125
130
  # Locate the stack of images if provided as a file
126
- self.stack_length = auto_load_number_of_frames(self.stack_path)
127
- if self.stack_length is None:
128
- stack = imread(self.stack_path)
129
- self.stack_length = len(stack)
130
- del stack
131
- gc.collect()
132
131
 
132
+ self.stack_length = auto_load_number_of_frames(self.stack_path)
133
133
  self.mid_time = self.stack_length // 2
134
134
  self.img_num_per_channel = _get_img_num_per_channel(np.arange(self.n_channels), self.stack_length, self.n_channels)
135
135
 
@@ -146,7 +146,7 @@ class StackVisualizer(QWidget, Styles):
146
146
  self.fig, self.ax = plt.subplots(figsize=(5,5),tight_layout=True) #figsize=(5, 5)
147
147
  self.canvas = FigureCanvas(self.fig, title=self.window_title, interactive=True)
148
148
  self.ax.clear()
149
- self.im = self.ax.imshow(self.init_frame, cmap='gray', interpolation='none', **self.imshow_kwargs)
149
+ self.im = self.ax.imshow(self.init_frame, cmap='gray', interpolation='none', zorder=0, **self.imshow_kwargs)
150
150
  if self.PxToUm is not None:
151
151
  scalebar = ScaleBar(self.PxToUm,
152
152
  "um",
@@ -250,12 +250,23 @@ class StackVisualizer(QWidget, Styles):
250
250
  self.last_frame = load_frames(self.img_num_per_channel[self.target_channel, self.stack_length-1],
251
251
  self.stack_path,
252
252
  normalize_input=False).astype(float)[:,:,0]
253
- self.change_frame(self.frame_slider.value())
253
+ self.change_frame_from_channel_switch(self.frame_slider.value())
254
+ self.channel_trigger = False
254
255
  self.init_contrast = False
255
256
 
257
def change_frame_from_channel_switch(self, value):
    """Redisplay frame `value` after a channel change, flagging it as a channel switch.

    `change_frame` reads `self.channel_trigger` to distinguish a channel switch
    from a genuine frame change. The flag is cleared afterwards, even if
    `change_frame` raises, so later slider-driven frame changes are never
    mistaken for channel switches.

    Parameters
    ----------
    value : int
        Index of the frame to redisplay.
    """
    self.channel_trigger = True
    try:
        self.change_frame(value)
    finally:
        # Always clear the flag; a stale True would suppress later detection resets.
        self.channel_trigger = False
+
256
262
  def change_frame(self, value):
263
+
257
264
  # Change the displayed frame based on slider value
258
-
265
+ if self.channel_trigger:
266
+ self.switch_from_channel = True
267
+ else:
268
+ self.switch_from_channel = False
269
+
259
270
  if self.mode=='virtual':
260
271
 
261
272
  self.init_frame = load_frames(self.img_num_per_channel[self.target_channel, value],
@@ -279,7 +290,7 @@ class StackVisualizer(QWidget, Styles):
279
290
 
280
291
def closeEvent(self, event):
    """Release display resources when the widget is closed.

    Closes the embedded canvas and also releases the matplotlib figure. The
    figure is created with ``plt.subplots`` and stays registered in pyplot until
    ``plt.close`` is called, so closing it here prevents figures from
    accumulating each time a viewer is opened and closed.

    Parameters
    ----------
    event : QCloseEvent
        The Qt close event (left untouched, so Qt's default acceptance applies).
    """
    self.canvas.close()
    fig = getattr(self, 'fig', None)
    if fig is not None:
        plt.close(fig)
283
294
 
284
295
 
285
296
  class ThresholdedStackVisualizer(StackVisualizer):
@@ -617,6 +628,249 @@ class CellEdgeVisualizer(StackVisualizer):
617
628
 
618
629
  self.edge_labels = contour_of_instance_segmentation(self.init_label, edge_size)
619
630
 
631
class SpotDetectionVisualizer(StackVisualizer):
    """Interactive viewer used to tune spot (blob) detection parameters.

    Shows a frame of the stack together with its label mask. The user picks a
    detection channel, a spot diameter and a detection threshold, and runs
    ``extract_blobs_in_image`` on demand. Detected spots are overlaid as red
    circles. "Add measurement" copies the chosen parameters back into optional
    widgets of a parent window and closes the viewer.

    Parameters
    ----------
    parent_channel_cb : QComboBox or None
        Parent combo box that receives the selected detection channel index.
    parent_diameter_le : QLineEdit or None
        Parent line edit that receives the spot diameter text.
    parent_threshold_le : QLineEdit or None
        Parent line edit that receives the detection threshold text.
    cell_type : str
        Population name; used to locate ``labels_<cell_type>`` masks on disk
        when `labels` is not given.
    labels : array-like or None
        In-memory label stack (T, Y, X). When None, one mask file per frame is
        read from disk.
    *args, **kwargs
        Forwarded to ``StackVisualizer``.

    Notes
    -----
    The stack source (``self.mode``, owned by ``StackVisualizer``) and the label
    source (``self.label_mode``) are tracked separately. An in-memory stack can
    therefore be combined with on-disk labels and vice versa.
    """

    def __init__(self, parent_channel_cb=None, parent_diameter_le=None, parent_threshold_le=None, cell_type='targets', labels=None, *args, **kwargs):

        super().__init__(*args, **kwargs)

        self.cell_type = cell_type
        self.labels = labels
        self.detection_channel = self.target_channel
        self.parent_channel_cb = parent_channel_cb
        self.parent_diameter_le = parent_diameter_le
        self.parent_threshold_le = parent_threshold_le
        # Last detection: spot radii in image pixels and (x, y) centres.
        self.spot_sizes = np.array([])
        self.spot_positions = np.empty((0, 2))

        self.floatValidator = QDoubleValidator()
        self.init_scatter()
        self.generate_detection_channel()
        self.generate_spot_detection_params()
        self.generate_add_measurement_btn()
        self.load_labels()
        self.change_frame(self.mid_time)

        # Keep circle radii expressed in image pixels while zooming/panning.
        self.ax.callbacks.connect('xlim_changed', self.update_marker_sizes)
        self.ax.callbacks.connect('ylim_changed', self.update_marker_sizes)

        self.apply_diam_btn.clicked.connect(self.detect_and_display_spots)
        self.apply_thresh_btn.clicked.connect(self.detect_and_display_spots)

        self.channels_cb.setCurrentIndex(self.target_channel)
        self.detection_channel_cb.setCurrentIndex(self.target_channel)

    def update_marker_sizes(self, event=None):
        """Rescale the scatter markers so each circle matches its spot radius.

        Called after a detection and whenever the axes limits change.
        `event` is the argument passed by matplotlib's axes callbacks and is unused.
        """
        if len(self.spot_sizes) == 0:
            return

        xlim = self.ax.get_xlim()
        ylim = self.ax.get_ylim()
        ax_width_in_pixels = self.ax.bbox.width
        ax_height_in_pixels = self.ax.bbox.height
        if ax_width_in_pixels <= 0 or ax_height_in_pixels <= 0:
            # Axes not laid out yet; no meaningful scale can be computed.
            return

        # Data units per display pixel. The y limits may be inverted (e.g. imshow's
        # default origin='upper'), so compare magnitudes.
        x_scale = abs(xlim[1] - xlim[0]) / ax_width_in_pixels
        y_scale = abs(ylim[1] - ylim[0]) / ax_height_in_pixels
        scale = min(x_scale, y_scale)
        if scale == 0:
            return

        radius_display_px = np.asarray(self.spot_sizes, dtype=float) / scale
        radius_pts = radius_display_px * (72. / self.fig.dpi)
        # scatter's `s` is in points**2 and a circle marker has diameter sqrt(s),
        # so a circle of radius r points requires s = (2 r)**2.
        sizes = (2. * radius_pts) ** 2

        self.spot_scat.set_sizes(sizes)
        self.fig.canvas.draw_idle()

    def init_scatter(self):
        """Create the (initially empty) scatter artist used to draw spots above the image."""
        self.spot_scat = self.ax.scatter([],[], s=50, facecolors='none', edgecolors='tab:red',zorder=100)
        self.canvas.canvas.draw()

    def change_frame(self, value):
        """Display frame `value` and load its label mask and detection-channel image.

        Detected spots are cleared on a genuine frame change, but kept when the
        call comes from a display-channel switch (same frame, other channel).
        """
        super().change_frame(value)
        if not self.switch_from_channel:
            self.reset_detection()

        self._load_label_frame(value)
        self._load_detection_image(value)

    def _load_label_frame(self, frame):
        """Load the label mask of `frame` into `self.init_label` (None if no masks exist)."""
        if self.label_mode == 'virtual':
            if len(self.mask_paths) == 0:
                self.init_label = None
            else:
                self.init_label = imread(self.mask_paths[frame])
        elif self.label_mode == 'direct':
            self.init_label = self.labels[frame,:,:]

    def _load_detection_image(self, frame):
        """Load `frame` of the detection channel into `self.target_img` as float."""
        if self.mode == 'virtual':
            self.target_img = load_frames(self.img_num_per_channel[self.detection_channel, frame],
                                          self.stack_path,
                                          normalize_input=False).astype(float)[:,:,0]
        elif self.mode == 'direct':
            self.target_img = self.stack[frame,:,:,self.detection_channel].astype(float)

    def detect_and_display_spots(self):
        """Run spot detection on the current detection image and overlay the spots."""
        self.reset_detection()
        if not self.control_valid_parameters():
            return
        if self.init_label is None:
            # No label mask available for this frame: nothing to detect within.
            return

        blobs_filtered = extract_blobs_in_image(self.target_img, self.init_label, threshold=self.thresh, diameter=self.diameter)
        if blobs_filtered is None or len(blobs_filtered) == 0:
            return

        blobs = np.asarray(blobs_filtered, dtype=float)
        # Each blob row is (y, x, sigma); matplotlib offsets are (x, y).
        self.spot_positions = blobs[:, [1, 0]]
        # skimage blob convention: radius is approximately sqrt(2) * sigma in 2D.
        self.spot_sizes = np.sqrt(2) * blobs[:, 2]

        self.spot_scat.set_offsets(self.spot_positions)
        self.update_marker_sizes()
        self.canvas.canvas.draw()

    def reset_detection(self):
        """Remove all spot markers from the display and forget the last detection."""
        self.spot_positions = np.empty((0, 2))
        self.spot_sizes = np.array([])
        self.spot_scat.set_offsets(self.spot_positions)
        self.canvas.canvas.draw()

    def load_labels(self):
        """Initialise the label source and the label mask of the middle frame.

        Sets `self.label_mode` to 'direct' when `labels` was provided in memory,
        or to 'virtual' when masks must be read from disk next to the stack file.
        """
        if self.labels is not None:

            if isinstance(self.labels, list):
                self.labels = np.array(self.labels)

            assert self.labels.ndim==3,'Wrong dimensions for the provided labels, expect TXY'
            assert len(self.labels)==self.stack_length

            self.label_mode = 'direct'
            self.init_label = self.labels[self.mid_time,:,:]
        else:
            self.label_mode = 'virtual'
            assert isinstance(self.stack_path, str)
            assert self.stack_path.endswith('.tif')
            self.locate_labels_virtual()

    def locate_labels_virtual(self):
        """Find per-frame mask files in ``labels_<cell_type>``, two levels above the stack file.

        Shows an error box and closes the viewer if no mask is found.
        """
        from glob import glob
        from pathlib import Path
        from natsort import natsorted
        from PyQt5.QtWidgets import QMessageBox

        labels_path = str(Path(self.stack_path).parent.parent) + os.sep + f'labels_{self.cell_type}' + os.sep
        self.mask_paths = natsorted(glob(labels_path + '*.tif'))

        if len(self.mask_paths) == 0:

            msgBox = QMessageBox()
            msgBox.setIcon(QMessageBox.Critical)
            msgBox.setText("No labels were found for the selected cells. Abort.")
            msgBox.setWindowTitle("Critical")
            msgBox.setStandardButtons(QMessageBox.Ok)
            msgBox.exec()
            self.close()
            # There is no mask to index; leave the label empty instead of crashing.
            self.init_label = None
            return

        self.init_label = imread(self.mask_paths[self.frame_slider.value()])

    def generate_detection_channel(self):
        """Add the combo box that selects the channel used for spot detection."""
        assert self.channel_names is not None
        assert len(self.channel_names)==self.n_channels

        channel_layout = QHBoxLayout()
        channel_layout.setContentsMargins(15,0,15,0)
        channel_layout.addWidget(QLabel('Detection\nchannel: '), 25)

        self.detection_channel_cb = QComboBox()
        self.detection_channel_cb.addItems(self.channel_names)
        self.detection_channel_cb.currentIndexChanged.connect(self.set_detection_channel_index)
        channel_layout.addWidget(self.detection_channel_cb, 75)
        self.canvas.layout.addLayout(channel_layout)

    def set_detection_channel_index(self, value):
        """Use channel `value` for detection and reload the displayed frame in that channel."""
        self.detection_channel = value
        self._load_detection_image(self.frame_slider.value())
        # Spots found on the previous channel no longer apply.
        self.reset_detection()

    def generate_spot_detection_params(self):
        """Add the diameter and threshold fields, each with its own 'Set' button."""
        self.spot_diam_le = QLineEdit('1')
        self.spot_diam_le.setValidator(self.floatValidator)
        self.apply_diam_btn = QPushButton('Set')
        self.apply_diam_btn.setStyleSheet(self.button_style_sheet_2)

        self.spot_thresh_le = QLineEdit('0')
        self.spot_thresh_le.setValidator(self.floatValidator)
        self.apply_thresh_btn = QPushButton('Set')
        self.apply_thresh_btn.setStyleSheet(self.button_style_sheet_2)

        self.spot_diam_le.textChanged.connect(self.control_valid_parameters)
        self.spot_thresh_le.textChanged.connect(self.control_valid_parameters)

        spot_diam_layout = QHBoxLayout()
        spot_diam_layout.setContentsMargins(15,0,15,0)
        spot_diam_layout.addWidget(QLabel('Spot diameter: '), 25)
        spot_diam_layout.addWidget(self.spot_diam_le, 65)
        spot_diam_layout.addWidget(self.apply_diam_btn, 10)
        self.canvas.layout.addLayout(spot_diam_layout)

        spot_thresh_layout = QHBoxLayout()
        spot_thresh_layout.setContentsMargins(15,0,15,0)
        spot_thresh_layout.addWidget(QLabel('Detection\nthreshold: '), 25)
        spot_thresh_layout.addWidget(self.spot_thresh_le, 65)
        spot_thresh_layout.addWidget(self.apply_thresh_btn, 10)
        self.canvas.layout.addLayout(spot_thresh_layout)

    def generate_add_measurement_btn(self):
        """Add the centred 'Add measurement' button that exports the parameters."""
        add_hbox = QHBoxLayout()
        self.add_measurement_btn = QPushButton('Add measurement')
        self.add_measurement_btn.clicked.connect(self.set_measurement_in_parent_list)
        self.add_measurement_btn.setIcon(icon(MDI6.plus,color="white"))
        self.add_measurement_btn.setIconSize(QSize(20, 20))
        self.add_measurement_btn.setStyleSheet(self.button_style_sheet)
        add_hbox.addWidget(QLabel(''),33)
        add_hbox.addWidget(self.add_measurement_btn, 33)
        add_hbox.addWidget(QLabel(''),33)
        self.canvas.layout.addLayout(add_hbox)

    def control_valid_parameters(self):
        """Parse the diameter and threshold fields and enable the action buttons accordingly.

        A comma is accepted as the decimal separator. Parsed values are stored in
        `self.diameter` and `self.thresh`.

        Returns
        -------
        bool
            True when both fields parse as floats.
        """
        valid_diam = True
        try:
            self.diameter = float(self.spot_diam_le.text().replace(',','.'))
        except ValueError:
            valid_diam = False

        valid_thresh = True
        try:
            self.thresh = float(self.spot_thresh_le.text().replace(',','.'))
        except ValueError:
            valid_thresh = False

        valid = valid_diam and valid_thresh
        for button in (self.apply_diam_btn, self.apply_thresh_btn, self.add_measurement_btn):
            button.setEnabled(valid)
        return valid

    def set_measurement_in_parent_list(self):
        """Copy channel, diameter and threshold into the parent widgets (if given), then close."""
        if self.parent_channel_cb is not None:
            self.parent_channel_cb.setCurrentIndex(self.detection_channel)
        if self.parent_diameter_le is not None:
            self.parent_diameter_le.setText(self.spot_diam_le.text())
        if self.parent_threshold_le is not None:
            self.parent_threshold_le.setText(self.spot_thresh_le.text())
        self.close()
873
+
620
874
  class CellSizeViewer(StackVisualizer):
621
875
 
622
876
  """
celldetective/io.py CHANGED
@@ -14,7 +14,7 @@ from btrack.datasets import cell_config
14
14
  from magicgui import magicgui
15
15
  from csbdeep.io import save_tiff_imagej_compatible
16
16
  from pathlib import Path, PurePath
17
- from shutil import copyfile
17
+ from shutil import copyfile, rmtree
18
18
  from celldetective.utils import ConfigSectionMap, extract_experiment_channels, _extract_labels_from_config, get_zenodo_files, download_zenodo_file
19
19
  import json
20
20
  from skimage.measure import regionprops_table
@@ -24,6 +24,54 @@ import concurrent.futures
24
24
  from tifffile import imwrite
25
25
  from stardist import fill_label_holes
26
26
 
27
def extract_experiment_from_well(well_path):
    """Return the experiment folder that directly contains a well folder.

    The layout is taken to be ``<experiment>/<well>``. A trailing separator on
    `well_path` is optional. The returned path has no trailing separator.
    """
    normalized = well_path if well_path.endswith(os.sep) else well_path + os.sep
    components = normalized.split(os.sep)
    # Drop the empty component left by the trailing separator, then the well name.
    del components[-2:]
    return os.sep.join(components)
33
+
34
def extract_well_from_position(pos_path):
    """Return the well folder, with a trailing separator, that contains a position folder.

    The layout is taken to be ``<well>/<position>``. A trailing separator on
    `pos_path` is optional.
    """
    normalized = pos_path if pos_path.endswith(os.sep) else pos_path + os.sep
    components = normalized.split(os.sep)
    # Drop the empty component left by the trailing separator, then the position name.
    del components[-2:]
    return f"{os.sep.join(components)}{os.sep}"
40
+
41
def extract_experiment_from_position(pos_path):
    """Return the experiment folder two levels above a position folder.

    The layout is taken to be ``<experiment>/<well>/<position>``. A trailing
    separator on `pos_path` is optional. The returned path has no trailing
    separator.
    """
    normalized = pos_path if pos_path.endswith(os.sep) else pos_path + os.sep
    components = normalized.split(os.sep)
    # Drop the trailing empty component, the position name and the well name.
    del components[-3:]
    return os.sep.join(components)
47
+
48
def collect_experiment_metadata(pos_path=None, well_path=None):
    """Gather the experiment-level metadata attached to a position or a well.

    Parameters
    ----------
    pos_path : str, optional
        Position folder (``<experiment>/<well>/<position>``). Takes precedence
        over `well_path` when both are given.
    well_path : str, optional
        Well folder (``<experiment>/<well>``), used when `pos_path` is None.

    Returns
    -------
    dict
        Keys ``pos_path``, ``pos_name`` (0 when no position is given),
        ``well_path``, ``well_name``, ``well_nbr``, ``experiment``, ``antibody``,
        ``concentration``, ``cell_type`` and ``pharmaceutical_agent``. The
        per-well entries are picked at the well's index in the experiment's
        well list.

    Raises
    ------
    ValueError
        If neither path is given, or if the well is not one of the experiment's wells.
    """
    if pos_path is not None:
        if not pos_path.endswith(os.sep):
            pos_path += os.sep
        experiment = extract_experiment_from_position(pos_path)
        well_path = extract_well_from_position(pos_path)
    elif well_path is not None:
        if not well_path.endswith(os.sep):
            well_path += os.sep
        experiment = extract_experiment_from_well(well_path)
    else:
        # Without this guard `experiment` would be unbound below (UnboundLocalError).
        raise ValueError("Provide either `pos_path` or `well_path`.")

    wells = list(get_experiment_wells(experiment))
    try:
        idx = wells.index(well_path)
    except ValueError as err:
        raise ValueError(f"Well {well_path!r} was not found among the wells of experiment {experiment!r}.") from err

    well_name, well_nbr = extract_well_name_and_number(well_path)
    if pos_path is not None:
        pos_name = extract_position_name(pos_path)
    else:
        pos_name = 0
    concentrations = get_experiment_concentrations(experiment, dtype=float)
    cell_types = get_experiment_cell_types(experiment)
    antibodies = get_experiment_antibodies(experiment)
    pharmaceutical_agents = get_experiment_pharmaceutical_agents(experiment)

    return {"pos_path": pos_path, "pos_name": pos_name, "well_path": well_path, "well_name": well_name, "well_nbr": well_nbr, "experiment": experiment, "antibody": antibodies[idx], "concentration": concentrations[idx], "cell_type": cell_types[idx], "pharmaceutical_agent": pharmaceutical_agents[idx]}
73
+
74
+
27
75
  def get_experiment_wells(experiment):
28
76
 
29
77
  """
@@ -626,10 +674,20 @@ def locate_stack(position, prefix='Aligned'):
626
674
  stack_path = glob(position + os.sep.join(['movie', f'{prefix}*.tif']))
627
675
  assert len(stack_path) > 0, f"No movie with prefix {prefix} found..."
628
676
  stack = imread(stack_path[0].replace('\\', '/'))
677
+ stack_length = auto_load_number_of_frames(stack_path[0])
678
+
629
679
  if stack.ndim == 4:
630
680
  stack = np.moveaxis(stack, 1, -1)
631
681
  elif stack.ndim == 3:
632
- stack = stack[:, :, :, np.newaxis]
682
+ if min(stack.shape)!=stack_length:
683
+ channel_axis = np.argmin(stack.shape)
684
+ if channel_axis!=(stack.ndim-1):
685
+ stack = np.moveaxis(stack, channel_axis, -1)
686
+ stack = stack[np.newaxis, :, :, :]
687
+ else:
688
+ stack = stack[:, :, :, np.newaxis]
689
+ elif stack.ndim==2:
690
+ stack = stack[np.newaxis, :, :, np.newaxis]
633
691
 
634
692
  return stack
635
693
 
@@ -712,12 +770,17 @@ def fix_missing_labels(position, population='target', prefix='Aligned'):
712
770
 
713
771
  if population.lower() == "target" or population.lower() == "targets":
714
772
  label_path = natsorted(glob(position + os.sep.join(["labels_targets", "*.tif"])))
773
+ path = position + os.sep + "labels_targets"
715
774
  elif population.lower() == "effector" or population.lower() == "effectors":
716
775
  label_path = natsorted(glob(position + os.sep.join(["labels_effectors", "*.tif"])))
776
+ path = position + os.sep + "labels_effectors"
717
777
 
718
- path = os.path.split(label_path[0])[0]
719
- int_valid = [int(lbl.split(os.sep)[-1].split('.')[0]) for lbl in label_path]
720
- to_create = [x for x in all_frames if x not in int_valid]
778
+ if label_path!=[]:
779
+ #path = os.path.split(label_path[0])[0]
780
+ int_valid = [int(lbl.split(os.sep)[-1].split('.')[0]) for lbl in label_path]
781
+ to_create = [x for x in all_frames if x not in int_valid]
782
+ else:
783
+ to_create = all_frames
721
784
  to_create = [str(x).zfill(4)+'.tif' for x in to_create]
722
785
  for file in to_create:
723
786
  imwrite(os.sep.join([path, file]), template)
@@ -859,6 +922,7 @@ def auto_load_number_of_frames(stack_path):
859
922
  return None
860
923
 
861
924
  stack_path = stack_path.replace('\\','/')
925
+ n_channels=1
862
926
 
863
927
  with TiffFile(stack_path) as tif:
864
928
  try:
@@ -868,7 +932,8 @@ def auto_load_number_of_frames(stack_path):
868
932
  tif_tags[name] = value
869
933
  img_desc = tif_tags["ImageDescription"]
870
934
  attr = img_desc.split("\n")
871
- except:
935
+ n_channels = int(attr[np.argmax([s.startswith("channels") for s in attr])].split("=")[-1])
936
+ except Exception as e:
872
937
  pass
873
938
  try:
874
939
  # Try nframes
@@ -895,6 +960,10 @@ def auto_load_number_of_frames(stack_path):
895
960
  if 'len_movie' not in locals():
896
961
  stack = imread(stack_path)
897
962
  len_movie = len(stack)
963
+ if len_movie==n_channels and stack.ndim==3:
964
+ len_movie = 1
965
+ if stack.ndim==2:
966
+ len_movie = 1
898
967
  del stack
899
968
  gc.collect()
900
969
 
@@ -1532,6 +1601,8 @@ def control_segmentation_napari(position, prefix='Aligned', population="target",
1532
1601
 
1533
1602
  for k, sq in enumerate(squares):
1534
1603
  print(f"ROI: {sq}")
1604
+ pad_to_256=False
1605
+
1535
1606
  xmin = int(sq[0, 1])
1536
1607
  xmax = int(sq[2, 1])
1537
1608
  if xmax < xmin:
@@ -1543,8 +1614,9 @@ def control_segmentation_napari(position, prefix='Aligned', population="target",
1543
1614
  print(f"{xmin=};{xmax=};{ymin=};{ymax=}")
1544
1615
  frame = viewer.layers['Image'].data[t][xmin:xmax, ymin:ymax]
1545
1616
  if frame.shape[1] < 256 or frame.shape[0] < 256:
1546
- print("crop too small!")
1547
- continue
1617
+ pad_to_256 = True
1618
+ print("Crop too small! Padding with zeros to reach 256*256 pixels...")
1619
+ #continue
1548
1620
  multichannel = [frame]
1549
1621
  for i in range(len(channel_indices) - 1):
1550
1622
  try:
@@ -1552,8 +1624,21 @@ def control_segmentation_napari(position, prefix='Aligned', population="target",
1552
1624
  multichannel.append(frame)
1553
1625
  except:
1554
1626
  pass
1555
- multichannel = np.array(multichannel)
1556
- save_tiff_imagej_compatible(annotation_folder + f"{exp_name}_{position.split(os.sep)[-2]}_{str(t).zfill(4)}_roi_{xmin}_{xmax}_{ymin}_{ymax}_labelled.tif", labels_layer[xmin:xmax,ymin:ymax].astype(np.int16), axes='YX')
1627
+ multichannel = np.array(multichannel)
1628
+ lab = labels_layer[xmin:xmax,ymin:ymax].astype(np.int16)
1629
+ if pad_to_256:
1630
+ shape = multichannel.shape
1631
+ pad_length_x = max([0,256 - multichannel.shape[1]])
1632
+ if pad_length_x>0 and pad_length_x%2==1:
1633
+ pad_length_x += 1
1634
+ pad_length_y = max([0,256 - multichannel.shape[2]])
1635
+ if pad_length_y>0 and pad_length_y%2==1:
1636
+ pad_length_y += 1
1637
+ padded_image = np.array([np.pad(im, ((pad_length_x//2,pad_length_x//2), (pad_length_y//2,pad_length_y//2)), mode='constant') for im in multichannel])
1638
+ padded_label = np.pad(lab,((pad_length_x//2,pad_length_x//2), (pad_length_y//2,pad_length_y//2)), mode='constant')
1639
+ lab = padded_label; multichannel = padded_image;
1640
+
1641
+ save_tiff_imagej_compatible(annotation_folder + f"{exp_name}_{position.split(os.sep)[-2]}_{str(t).zfill(4)}_roi_{xmin}_{xmax}_{ymin}_{ymax}_labelled.tif", lab, axes='YX')
1557
1642
  save_tiff_imagej_compatible(annotation_folder + f"{exp_name}_{position.split(os.sep)[-2]}_{str(t).zfill(4)}_roi_{xmin}_{xmax}_{ymin}_{ymax}.tif", multichannel, axes='CYX')
1558
1643
  info = {"spatial_calibration": spatial_calibration, "channels": list(channel_names), 'cell_type': ct, 'antibody': ab, 'concentration': conc, 'pharmaceutical_agent': pa}
1559
1644
  info_name = annotation_folder + f"{exp_name}_{position.split(os.sep)[-2]}_{str(t).zfill(4)}_roi_{xmin}_{xmax}_{ymin}_{ymax}.json"
@@ -1592,6 +1677,7 @@ def control_segmentation_napari(position, prefix='Aligned', population="target",
1592
1677
  population += 's'
1593
1678
  output_folder = position + f'labels_{population}{os.sep}'
1594
1679
 
1680
+ print(f"{stack.shape}")
1595
1681
  viewer = napari.Viewer()
1596
1682
  viewer.add_image(stack, channel_axis=-1, colormap=["gray"] * stack.shape[-1])
1597
1683
  viewer.add_labels(labels.astype(int), name='segmentation', opacity=0.4)
@@ -1779,10 +1865,24 @@ def get_segmentation_models_list(mode='targets', return_path=False):
1779
1865
 
1780
1866
  available_models = natsorted(glob(modelpath + '*/'))
1781
1867
  available_models = [m.replace('\\', '/').split('/')[-2] for m in available_models]
1868
+
1869
+ # Auto model cleanup
1870
+ to_remove = []
1871
+ for model in available_models:
1872
+ path = modelpath + model
1873
+ files = glob(path+os.sep+"*")
1874
+ if path+os.sep+"config_input.json" not in files:
1875
+ rmtree(path)
1876
+ to_remove.append(model)
1877
+ for m in to_remove:
1878
+ available_models.remove(m)
1879
+
1880
+
1782
1881
  for rm in repository_models:
1783
1882
  if rm not in available_models:
1784
1883
  available_models.append(rm)
1785
1884
 
1885
+
1786
1886
  if not return_path:
1787
1887
  return available_models
1788
1888
  else: