celldetective 1.3.2-py3-none-any.whl → 1.3.4-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (41)
  1. celldetective/__main__.py +30 -4
  2. celldetective/_version.py +1 -1
  3. celldetective/extra_properties.py +21 -0
  4. celldetective/filters.py +15 -2
  5. celldetective/gui/InitWindow.py +28 -34
  6. celldetective/gui/analyze_block.py +3 -498
  7. celldetective/gui/classifier_widget.py +1 -1
  8. celldetective/gui/control_panel.py +100 -29
  9. celldetective/gui/generic_signal_plot.py +35 -18
  10. celldetective/gui/gui_utils.py +143 -2
  11. celldetective/gui/layouts.py +7 -6
  12. celldetective/gui/measurement_options.py +38 -43
  13. celldetective/gui/plot_measurements.py +5 -13
  14. celldetective/gui/plot_signals_ui.py +30 -30
  15. celldetective/gui/process_block.py +66 -197
  16. celldetective/gui/retrain_segmentation_model_options.py +3 -1
  17. celldetective/gui/signal_annotator.py +50 -32
  18. celldetective/gui/signal_annotator2.py +7 -4
  19. celldetective/gui/styles.py +13 -0
  20. celldetective/gui/survival_ui.py +8 -21
  21. celldetective/gui/tableUI.py +1 -2
  22. celldetective/gui/thresholds_gui.py +195 -205
  23. celldetective/gui/viewers.py +262 -12
  24. celldetective/io.py +85 -11
  25. celldetective/measure.py +128 -88
  26. celldetective/models/segmentation_effectors/ricm_bf_all_last/config_input.json +79 -0
  27. celldetective/models/segmentation_effectors/ricm_bf_all_last/ricm_bf_all_last +0 -0
  28. celldetective/models/segmentation_effectors/ricm_bf_all_last/training_instructions.json +37 -0
  29. celldetective/models/segmentation_effectors/test-transfer/config_input.json +39 -0
  30. celldetective/models/segmentation_effectors/test-transfer/test-transfer +0 -0
  31. celldetective/neighborhood.py +0 -2
  32. celldetective/scripts/measure_cells.py +21 -9
  33. celldetective/signals.py +77 -66
  34. celldetective/tracking.py +19 -13
  35. {celldetective-1.3.2.dist-info → celldetective-1.3.4.dist-info}/METADATA +12 -10
  36. {celldetective-1.3.2.dist-info → celldetective-1.3.4.dist-info}/RECORD +41 -36
  37. {celldetective-1.3.2.dist-info → celldetective-1.3.4.dist-info}/WHEEL +1 -1
  38. tests/test_qt.py +5 -3
  39. {celldetective-1.3.2.dist-info → celldetective-1.3.4.dist-info}/LICENSE +0 -0
  40. {celldetective-1.3.2.dist-info → celldetective-1.3.4.dist-info}/entry_points.txt +0 -0
  41. {celldetective-1.3.2.dist-info → celldetective-1.3.4.dist-info}/top_level.txt +0 -0
celldetective/gui/viewers.py CHANGED
@@ -2,7 +2,7 @@ import numpy as np
  from celldetective.io import auto_load_number_of_frames, load_frames
  from celldetective.filters import *
  from celldetective.segmentation import filter_image, threshold_image
- from celldetective.measure import contour_of_instance_segmentation
+ from celldetective.measure import contour_of_instance_segmentation, extract_blobs_in_image
  from celldetective.utils import _get_img_num_per_channel, estimate_unreliable_edge
  from tifffile import imread
  import matplotlib.pyplot as plt
@@ -13,7 +13,7 @@ import os

  from PyQt5.QtWidgets import QWidget, QHBoxLayout, QPushButton, QLabel, QComboBox, QLineEdit, QListWidget, QShortcut
  from PyQt5.QtCore import Qt, QSize
- from PyQt5.QtGui import QKeySequence
+ from PyQt5.QtGui import QKeySequence, QDoubleValidator
  from celldetective.gui.gui_utils import FigureCanvas, center_window, QuickSliderLayout, QHSeperationLine, ThresholdLineEdit
  from celldetective.gui import Styles
  from superqt import QLabeledDoubleSlider, QLabeledSlider, QLabeledDoubleRangeSlider
@@ -80,6 +80,7 @@ class StackVisualizer(QWidget, Styles):
  self.imshow_kwargs = imshow_kwargs
  self.PxToUm = PxToUm
  self.init_contrast = False
+ self.channel_trigger = False

  self.load_stack() # need to get stack, frame etc
  self.generate_figure_canvas()
@@ -123,13 +124,8 @@ class StackVisualizer(QWidget, Styles):

  def locate_image_virtual(self):
  # Locate the stack of images if provided as a file
- self.stack_length = auto_load_number_of_frames(self.stack_path)
- if self.stack_length is None:
- stack = imread(self.stack_path)
- self.stack_length = len(stack)
- del stack
- gc.collect()

+ self.stack_length = auto_load_number_of_frames(self.stack_path)
  self.mid_time = self.stack_length // 2
  self.img_num_per_channel = _get_img_num_per_channel(np.arange(self.n_channels), self.stack_length, self.n_channels)

@@ -146,7 +142,7 @@ class StackVisualizer(QWidget, Styles):
  self.fig, self.ax = plt.subplots(figsize=(5,5),tight_layout=True) #figsize=(5, 5)
  self.canvas = FigureCanvas(self.fig, title=self.window_title, interactive=True)
  self.ax.clear()
- self.im = self.ax.imshow(self.init_frame, cmap='gray', interpolation='none', **self.imshow_kwargs)
+ self.im = self.ax.imshow(self.init_frame, cmap='gray', interpolation='none', zorder=0, **self.imshow_kwargs)
  if self.PxToUm is not None:
  scalebar = ScaleBar(self.PxToUm,
  "um",
@@ -250,12 +246,23 @@ class StackVisualizer(QWidget, Styles):
  self.last_frame = load_frames(self.img_num_per_channel[self.target_channel, self.stack_length-1],
  self.stack_path,
  normalize_input=False).astype(float)[:,:,0]
- self.change_frame(self.frame_slider.value())
+ self.change_frame_from_channel_switch(self.frame_slider.value())
+ self.channel_trigger = False
  self.init_contrast = False

+ def change_frame_from_channel_switch(self, value):
+
+ self.channel_trigger = True
+ self.change_frame(value)
+
  def change_frame(self, value):
+
  # Change the displayed frame based on slider value
-
+ if self.channel_trigger:
+ self.switch_from_channel = True
+ else:
+ self.switch_from_channel = False
+
  if self.mode=='virtual':

  self.init_frame = load_frames(self.img_num_per_channel[self.target_channel, value],
@@ -279,7 +286,7 @@ class StackVisualizer(QWidget, Styles):

  def closeEvent(self, event):
  # Event handler for closing the widget
- self.canvas.close()
+ self.canvas.close()


  class ThresholdedStackVisualizer(StackVisualizer):
@@ -617,6 +624,249 @@ class CellEdgeVisualizer(StackVisualizer):

  self.edge_labels = contour_of_instance_segmentation(self.init_label, edge_size)

+ class SpotDetectionVisualizer(StackVisualizer):
+
+ def __init__(self, parent_channel_cb=None, parent_diameter_le=None, parent_threshold_le=None, cell_type='targets', labels=None, *args, **kwargs):
+
+ super().__init__(*args, **kwargs)
+
+ self.cell_type = cell_type
+ self.labels = labels
+ self.detection_channel = self.target_channel
+ self.parent_channel_cb = parent_channel_cb
+ self.parent_diameter_le = parent_diameter_le
+ self.parent_threshold_le = parent_threshold_le
+ self.spot_sizes = []
+
+ self.floatValidator = QDoubleValidator()
+ self.init_scatter()
+ self.generate_detection_channel()
+ self.generate_spot_detection_params()
+ self.generate_add_measurement_btn()
+ self.load_labels()
+ self.change_frame(self.mid_time)
+
+ self.ax.callbacks.connect('xlim_changed', self.update_marker_sizes)
+ self.ax.callbacks.connect('ylim_changed', self.update_marker_sizes)
+
+ self.apply_diam_btn.clicked.connect(self.detect_and_display_spots)
+ self.apply_thresh_btn.clicked.connect(self.detect_and_display_spots)
+
+ self.channels_cb.setCurrentIndex(self.target_channel)
+ self.detection_channel_cb.setCurrentIndex(self.target_channel)
+
+ def update_marker_sizes(self, event=None):
+
+ # Get axis bounds
+ xlim = self.ax.get_xlim()
+ ylim = self.ax.get_ylim()
+
+ # Data-to-pixel scale
+ ax_width_in_pixels = self.ax.bbox.width
+ ax_height_in_pixels = self.ax.bbox.height
+
+ x_scale = (xlim[1] - xlim[0]) / ax_width_in_pixels
+ y_scale = (ylim[1] - ylim[0]) / ax_height_in_pixels
+
+ # Choose the smaller scale for square pixels
+ scale = min(x_scale, y_scale)
+
+ # Convert radius_px to data units
+ if len(self.spot_sizes)>0:
+
+ radius_data_units = self.spot_sizes / scale
+
+ # Convert to scatter `s` size (points squared)
+ radius_pts = radius_data_units * (72. / self.fig.dpi )
+ size = np.pi * (radius_pts ** 2)
+
+ # Update scatter sizes
+ self.spot_scat.set_sizes(size)
+ self.fig.canvas.draw_idle()
+
+ def init_scatter(self):
+ self.spot_scat = self.ax.scatter([],[], s=50, facecolors='none', edgecolors='tab:red',zorder=100)
+ self.canvas.canvas.draw()
+
+ def change_frame(self, value):
+
+ super().change_frame(value)
+ if not self.switch_from_channel:
+ self.reset_detection()
+
+ if self.mode=='virtual':
+ self.init_label = imread(self.mask_paths[value])
+ self.target_img = load_frames(self.img_num_per_channel[self.detection_channel, value],
+ self.stack_path,
+ normalize_input=False).astype(float)[:,:,0]
+ elif self.mode=='direct':
+ self.init_label = self.labels[value,:,:]
+ self.target_img = self.stack[value,:,:,self.detection_channel].copy()
+
+ def detect_and_display_spots(self):
+
+ self.reset_detection()
+ self.control_valid_parameters() # set current diam and threshold
+ blobs_filtered = extract_blobs_in_image(self.target_img, self.init_label,threshold=self.thresh, diameter=self.diameter)
+ if blobs_filtered is not None:
+ self.spot_positions = np.array([[x,y] for y,x,_ in blobs_filtered])
+
+ self.spot_sizes = np.sqrt(2)*np.array([sig for _,_,sig in blobs_filtered])
+ print(f"{self.spot_sizes=}")
+ #radius_pts = self.spot_sizes * (self.fig.dpi / 72.0)
+ #sizes = np.pi*(radius_pts**2)
+
+ self.spot_scat.set_offsets(self.spot_positions)
+ #self.spot_scat.set_sizes(sizes)
+ self.update_marker_sizes()
+ self.canvas.canvas.draw()
+
+ def reset_detection(self):
+
+ self.ax.scatter([], []).get_offsets()
+ empty_offset = np.ma.masked_array([0, 0], mask=True)
+ self.spot_scat.set_offsets(empty_offset)
+ self.canvas.canvas.draw()
+
+ def load_labels(self):
+
+ # Load the cell labels
+ if self.labels is not None:
+
+ if isinstance(self.labels, list):
+ self.labels = np.array(self.labels)
+
+ assert self.labels.ndim==3,'Wrong dimensions for the provided labels, expect TXY'
+ assert len(self.labels)==self.stack_length
+
+ self.mode = 'direct'
+ self.init_label = self.labels[self.mid_time,:,:]
+ else:
+ self.mode = 'virtual'
+ assert isinstance(self.stack_path, str)
+ assert self.stack_path.endswith('.tif')
+ self.locate_labels_virtual()
+
+ def locate_labels_virtual(self):
+ # Locate virtual labels
+
+ labels_path = str(Path(self.stack_path).parent.parent) + os.sep + f'labels_{self.cell_type}' + os.sep
+ self.mask_paths = natsorted(glob(labels_path + '*.tif'))
+
+ if len(self.mask_paths) == 0:
+
+ msgBox = QMessageBox()
+ msgBox.setIcon(QMessageBox.Critical)
+ msgBox.setText("No labels were found for the selected cells. Abort.")
+ msgBox.setWindowTitle("Critical")
+ msgBox.setStandardButtons(QMessageBox.Ok)
+ returnValue = msgBox.exec()
+ self.close()
+
+ self.init_label = imread(self.mask_paths[self.frame_slider.value()])
+
+ def generate_detection_channel(self):
+
+ assert self.channel_names is not None
+ assert len(self.channel_names)==self.n_channels
+
+ channel_layout = QHBoxLayout()
+ channel_layout.setContentsMargins(15,0,15,0)
+ channel_layout.addWidget(QLabel('Detection\nchannel: '), 25)
+
+ self.detection_channel_cb = QComboBox()
+ self.detection_channel_cb.addItems(self.channel_names)
+ self.detection_channel_cb.currentIndexChanged.connect(self.set_detection_channel_index)
+ channel_layout.addWidget(self.detection_channel_cb, 75)
+ self.canvas.layout.addLayout(channel_layout)
+
+ def set_detection_channel_index(self, value):
+
+ self.detection_channel = value
+ if self.mode == 'direct':
+ self.last_frame = self.stack[-1,:,:,self.target_channel]
+ elif self.mode == 'virtual':
+ self.target_img = load_frames(self.img_num_per_channel[self.detection_channel, self.stack_length-1],
+ self.stack_path,
+ normalize_input=False).astype(float)[:,:,0]
+
+ def generate_spot_detection_params(self):
+
+ self.spot_diam_le = QLineEdit('1')
+ self.spot_diam_le.setValidator(self.floatValidator)
+ self.apply_diam_btn = QPushButton('Set')
+ self.apply_diam_btn.setStyleSheet(self.button_style_sheet_2)
+
+ self.spot_thresh_le = QLineEdit('0')
+ self.spot_thresh_le.setValidator(self.floatValidator)
+ self.apply_thresh_btn = QPushButton('Set')
+ self.apply_thresh_btn.setStyleSheet(self.button_style_sheet_2)
+
+ self.spot_diam_le.textChanged.connect(self.control_valid_parameters)
+ self.spot_thresh_le.textChanged.connect(self.control_valid_parameters)
+
+ spot_diam_layout = QHBoxLayout()
+ spot_diam_layout.setContentsMargins(15,0,15,0)
+ spot_diam_layout.addWidget(QLabel('Spot diameter: '), 25)
+ spot_diam_layout.addWidget(self.spot_diam_le, 65)
+ spot_diam_layout.addWidget(self.apply_diam_btn, 10)
+ self.canvas.layout.addLayout(spot_diam_layout)
+
+ spot_thresh_layout = QHBoxLayout()
+ spot_thresh_layout.setContentsMargins(15,0,15,0)
+ spot_thresh_layout.addWidget(QLabel('Detection\nthreshold: '), 25)
+ spot_thresh_layout.addWidget(self.spot_thresh_le, 65)
+ spot_thresh_layout.addWidget(self.apply_thresh_btn, 10)
+ self.canvas.layout.addLayout(spot_thresh_layout)
+
+ def generate_add_measurement_btn(self):
+
+ add_hbox = QHBoxLayout()
+ self.add_measurement_btn = QPushButton('Add measurement')
+ self.add_measurement_btn.clicked.connect(self.set_measurement_in_parent_list)
+ self.add_measurement_btn.setIcon(icon(MDI6.plus,color="white"))
+ self.add_measurement_btn.setIconSize(QSize(20, 20))
+ self.add_measurement_btn.setStyleSheet(self.button_style_sheet)
+ add_hbox.addWidget(QLabel(''),33)
+ add_hbox.addWidget(self.add_measurement_btn, 33)
+ add_hbox.addWidget(QLabel(''),33)
+ self.canvas.layout.addLayout(add_hbox)
+
+ def control_valid_parameters(self):
+
+ valid_diam = False
+ try:
+ self.diameter = float(self.spot_diam_le.text().replace(',','.'))
+ valid_diam = True
+ except:
+ valid_diam = False
+
+ valid_thresh = False
+ try:
+ self.thresh = float(self.spot_thresh_le.text().replace(',','.'))
+ valid_thresh = True
+ except:
+ valid_thresh = False
+
+ if valid_diam and valid_thresh:
+ self.apply_diam_btn.setEnabled(True)
+ self.apply_thresh_btn.setEnabled(True)
+ self.add_measurement_btn.setEnabled(True)
+ else:
+ self.apply_diam_btn.setEnabled(False)
+ self.apply_thresh_btn.setEnabled(False)
+ self.add_measurement_btn.setEnabled(False)
+
+ def set_measurement_in_parent_list(self):
+
+ if self.parent_channel_cb is not None:
+ self.parent_channel_cb.setCurrentIndex(self.detection_channel)
+ if self.parent_diameter_le is not None:
+ self.parent_diameter_le.setText(self.spot_diam_le.text())
+ if self.parent_threshold_le is not None:
+ self.parent_threshold_le.setText(self.spot_thresh_le.text())
+ self.close()
+
  class CellSizeViewer(StackVisualizer):

  """
celldetective/io.py CHANGED
@@ -14,7 +14,7 @@ from btrack.datasets import cell_config
  from magicgui import magicgui
  from csbdeep.io import save_tiff_imagej_compatible
  from pathlib import Path, PurePath
- from shutil import copyfile
+ from shutil import copyfile, rmtree
  from celldetective.utils import ConfigSectionMap, extract_experiment_channels, _extract_labels_from_config, get_zenodo_files, download_zenodo_file
  import json
  from skimage.measure import regionprops_table
@@ -69,7 +69,7 @@ def collect_experiment_metadata(pos_path=None, well_path=None):
  antibodies = get_experiment_antibodies(experiment)
  pharmaceutical_agents = get_experiment_pharmaceutical_agents(experiment)

- return {"pos_path": pos_path, "pos_name": pos_name, "well_path": well_path, "well_name": well_name, "well_nbr": well_nbr, "experiment": experiment, "antibody": antibodies[idx], "concentration": concentrations[idx], "cell_type": cell_types[idx], "pharmaceutical_agent": pharmaceutical_agents[idx]}
+ return {"pos_path": pos_path, "position": pos_path, "pos_name": pos_name, "well_path": well_path, "well_name": well_name, "well_nbr": well_nbr, "experiment": experiment, "antibody": antibodies[idx], "concentration": concentrations[idx], "cell_type": cell_types[idx], "pharmaceutical_agent": pharmaceutical_agents[idx]}


  def get_experiment_wells(experiment):
@@ -674,14 +674,24 @@ def locate_stack(position, prefix='Aligned'):
  stack_path = glob(position + os.sep.join(['movie', f'{prefix}*.tif']))
  assert len(stack_path) > 0, f"No movie with prefix {prefix} found..."
  stack = imread(stack_path[0].replace('\\', '/'))
+ stack_length = auto_load_number_of_frames(stack_path[0])
+
  if stack.ndim == 4:
  stack = np.moveaxis(stack, 1, -1)
  elif stack.ndim == 3:
- stack = stack[:, :, :, np.newaxis]
+ if min(stack.shape)!=stack_length:
+ channel_axis = np.argmin(stack.shape)
+ if channel_axis!=(stack.ndim-1):
+ stack = np.moveaxis(stack, channel_axis, -1)
+ stack = stack[np.newaxis, :, :, :]
+ else:
+ stack = stack[:, :, :, np.newaxis]
+ elif stack.ndim==2:
+ stack = stack[np.newaxis, :, :, np.newaxis]

  return stack

- def locate_labels(position, population='target'):
+ def locate_labels(position, population='target', frames=None):

  """

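The new 3-D branch of locate_stack has to decide whether a 3-D TIFF is a time series (T, Y, X) or a single multichannel frame (C, Y, X), using the frame count returned by auto_load_number_of_frames. A minimal numpy sketch of that normalization, with the frame count passed in directly; the helper name to_TYXC is illustrative and not part of celldetective:

import numpy as np

def to_TYXC(stack, stack_length):
    # Normalize a 2-D/3-D/4-D stack to (T, Y, X, C), mirroring the branch above.
    if stack.ndim == 4:
        stack = np.moveaxis(stack, 1, -1)            # (T, C, Y, X) -> (T, Y, X, C)
    elif stack.ndim == 3:
        if min(stack.shape) != stack_length:
            # The smallest axis is taken as the channel axis of a single frame.
            channel_axis = int(np.argmin(stack.shape))
            if channel_axis != stack.ndim - 1:
                stack = np.moveaxis(stack, channel_axis, -1)
            stack = stack[np.newaxis, :, :, :]       # add the time axis
        else:
            stack = stack[:, :, :, np.newaxis]       # grayscale time series
    elif stack.ndim == 2:
        stack = stack[np.newaxis, :, :, np.newaxis]  # single grayscale image
    return stack

print(to_TYXC(np.zeros((3, 512, 512)), stack_length=1).shape)    # (1, 512, 512, 3)
print(to_TYXC(np.zeros((40, 512, 512)), stack_length=40).shape)  # (40, 512, 512, 1)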
@@ -722,7 +732,33 @@ def locate_labels(position, population='target'):
  label_path = natsorted(glob(position + os.sep.join(["labels_targets", "*.tif"])))
  elif population.lower() == "effector" or population.lower() == "effectors":
  label_path = natsorted(glob(position + os.sep.join(["labels_effectors", "*.tif"])))
- labels = np.array([imread(i.replace('\\', '/')) for i in label_path])
+
+ label_names = [os.path.split(lbl)[-1] for lbl in label_path]
+
+ if frames is None:
+
+ labels = np.array([imread(i.replace('\\', '/')) for i in label_path])
+
+ elif isinstance(frames, (int,float, np.int_)):
+
+ tzfill = str(int(frames)).zfill(4)
+ idx = label_names.index(f"{tzfill}.tif")
+ if idx==-1:
+ labels = None
+ else:
+ labels = np.array(imread(label_path[idx].replace('\\', '/')))
+
+ elif isinstance(frames, (list,np.ndarray)):
+ labels = []
+ for f in frames:
+ tzfill = str(int(f)).zfill(4)
+ idx = label_names.index(f"{tzfill}.tif")
+ if idx==-1:
+ labels.append(None)
+ else:
+ labels.append(np.array(imread(label_path[idx].replace('\\', '/'))))
+ else:
+ print('Frames argument must be None, int or list...')

  return labels

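With the new frames argument, callers can load a single label frame or a subset instead of reading every TIFF in the folder. A hedged usage sketch, assuming a standard celldetective position folder (the path below is illustrative):

from celldetective.io import locate_labels

pos = "/path/to/experiment/W1/100/"   # illustrative position folder

all_labels  = locate_labels(pos, population='target')                     # full (T, Y, X) array
one_frame   = locate_labels(pos, population='target', frames=3)           # single 2-D label image
some_frames = locate_labels(pos, population='target', frames=[0, 5, 10])  # list of 2-D label images

Note that when frames is a list or array, the branch above returns a plain Python list (one entry per requested frame) rather than a stacked array.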
@@ -773,7 +809,8 @@ def fix_missing_labels(position, population='target', prefix='Aligned'):
  to_create = all_frames
  to_create = [str(x).zfill(4)+'.tif' for x in to_create]
  for file in to_create:
- imwrite(os.sep.join([path, file]), template)
+ save_tiff_imagej_compatible(os.sep.join([path, file]), template.astype(np.int16), axes='YX')
+ #imwrite(os.sep.join([path, file]), template.astype(int))


  def locate_stack_and_labels(position, prefix='Aligned', population="target"):
@@ -912,6 +949,7 @@ def auto_load_number_of_frames(stack_path):
  return None

  stack_path = stack_path.replace('\\','/')
+ n_channels=1

  with TiffFile(stack_path) as tif:
  try:
@@ -921,7 +959,8 @@ def auto_load_number_of_frames(stack_path):
  tif_tags[name] = value
  img_desc = tif_tags["ImageDescription"]
  attr = img_desc.split("\n")
- except:
+ n_channels = int(attr[np.argmax([s.startswith("channels") for s in attr])].split("=")[-1])
+ except Exception as e:
  pass
  try:
  # Try nframes
@@ -948,6 +987,10 @@ def auto_load_number_of_frames(stack_path):
  if 'len_movie' not in locals():
  stack = imread(stack_path)
  len_movie = len(stack)
+ if len_movie==n_channels and stack.ndim==3:
+ len_movie = 1
+ if stack.ndim==2:
+ len_movie = 1
  del stack
  gc.collect()

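The n_channels bookkeeping added to auto_load_number_of_frames disambiguates the fallback path: a 3-D array whose first axis length equals the channel count is a single multichannel frame, not a movie. A minimal sketch of that correction, with the channel count (normally parsed from the ImageJ ImageDescription tag above) passed in directly; the helper name infer_movie_length is illustrative:

import numpy as np

def infer_movie_length(stack, n_channels=1):
    # len(stack) counts the first axis, which is only the number of frames
    # when the array is not a lone multichannel frame or a single image.
    len_movie = len(stack)
    if stack.ndim == 3 and len_movie == n_channels:
        len_movie = 1   # (C, Y, X): one frame with C channels
    if stack.ndim == 2:
        len_movie = 1   # (Y, X): a single grayscale image
    return len_movie

print(infer_movie_length(np.zeros((2, 256, 256)), n_channels=2))   # 1
print(infer_movie_length(np.zeros((30, 256, 256)), n_channels=1))  # 30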
@@ -1585,6 +1628,8 @@ def control_segmentation_napari(position, prefix='Aligned', population="target",

  for k, sq in enumerate(squares):
  print(f"ROI: {sq}")
+ pad_to_256=False
+
  xmin = int(sq[0, 1])
  xmax = int(sq[2, 1])
  if xmax < xmin:
@@ -1596,8 +1641,9 @@ def control_segmentation_napari(position, prefix='Aligned', population="target",
  print(f"{xmin=};{xmax=};{ymin=};{ymax=}")
  frame = viewer.layers['Image'].data[t][xmin:xmax, ymin:ymax]
  if frame.shape[1] < 256 or frame.shape[0] < 256:
- print("crop too small!")
- continue
+ pad_to_256 = True
+ print("Crop too small! Padding with zeros to reach 256*256 pixels...")
+ #continue
  multichannel = [frame]
  for i in range(len(channel_indices) - 1):
  try:
@@ -1605,8 +1651,21 @@ def control_segmentation_napari(position, prefix='Aligned', population="target",
  multichannel.append(frame)
  except:
  pass
- multichannel = np.array(multichannel)
- save_tiff_imagej_compatible(annotation_folder + f"{exp_name}_{position.split(os.sep)[-2]}_{str(t).zfill(4)}_roi_{xmin}_{xmax}_{ymin}_{ymax}_labelled.tif", labels_layer[xmin:xmax,ymin:ymax].astype(np.int16), axes='YX')
+ multichannel = np.array(multichannel)
+ lab = labels_layer[xmin:xmax,ymin:ymax].astype(np.int16)
+ if pad_to_256:
+ shape = multichannel.shape
+ pad_length_x = max([0,256 - multichannel.shape[1]])
+ if pad_length_x>0 and pad_length_x%2==1:
+ pad_length_x += 1
+ pad_length_y = max([0,256 - multichannel.shape[2]])
+ if pad_length_y>0 and pad_length_y%2==1:
+ pad_length_y += 1
+ padded_image = np.array([np.pad(im, ((pad_length_x//2,pad_length_x//2), (pad_length_y//2,pad_length_y//2)), mode='constant') for im in multichannel])
+ padded_label = np.pad(lab,((pad_length_x//2,pad_length_x//2), (pad_length_y//2,pad_length_y//2)), mode='constant')
+ lab = padded_label; multichannel = padded_image;
+
+ save_tiff_imagej_compatible(annotation_folder + f"{exp_name}_{position.split(os.sep)[-2]}_{str(t).zfill(4)}_roi_{xmin}_{xmax}_{ymin}_{ymax}_labelled.tif", lab, axes='YX')
  save_tiff_imagej_compatible(annotation_folder + f"{exp_name}_{position.split(os.sep)[-2]}_{str(t).zfill(4)}_roi_{xmin}_{xmax}_{ymin}_{ymax}.tif", multichannel, axes='CYX')
  info = {"spatial_calibration": spatial_calibration, "channels": list(channel_names), 'cell_type': ct, 'antibody': ab, 'concentration': conc, 'pharmaceutical_agent': pa}
  info_name = annotation_folder + f"{exp_name}_{position.split(os.sep)[-2]}_{str(t).zfill(4)}_roi_{xmin}_{xmax}_{ymin}_{ymax}.json"
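Instead of skipping ROIs smaller than 256×256 pixels, the annotation export now zero-pads them symmetrically, rounding odd pad lengths up to the next even number so both sides get the same margin. A standalone sketch of the same arithmetic on a single (Y, X) crop; the function name pad_crop_to_min_size is illustrative and not part of celldetective:

import numpy as np

def pad_crop_to_min_size(image, min_size=256):
    # Symmetric zero-padding so each spatial dimension reaches at least min_size.
    pad_x = max(0, min_size - image.shape[0])
    if pad_x > 0 and pad_x % 2 == 1:
        pad_x += 1                       # round up to keep the padding symmetric
    pad_y = max(0, min_size - image.shape[1])
    if pad_y > 0 and pad_y % 2 == 1:
        pad_y += 1
    return np.pad(image, ((pad_x // 2, pad_x // 2), (pad_y // 2, pad_y // 2)), mode='constant')

print(pad_crop_to_min_size(np.ones((180, 201))).shape)   # (256, 257): the odd deficit rounds up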
@@ -1645,6 +1704,7 @@ def control_segmentation_napari(position, prefix='Aligned', population="target",
  population += 's'
  output_folder = position + f'labels_{population}{os.sep}'

+ print(f"{stack.shape}")
  viewer = napari.Viewer()
  viewer.add_image(stack, channel_axis=-1, colormap=["gray"] * stack.shape[-1])
  viewer.add_labels(labels.astype(int), name='segmentation', opacity=0.4)
@@ -1832,10 +1892,24 @@ def get_segmentation_models_list(mode='targets', return_path=False):

  available_models = natsorted(glob(modelpath + '*/'))
  available_models = [m.replace('\\', '/').split('/')[-2] for m in available_models]
+
+ # Auto model cleanup
+ to_remove = []
+ for model in available_models:
+ path = modelpath + model
+ files = glob(path+os.sep+"*")
+ if path+os.sep+"config_input.json" not in files:
+ rmtree(path)
+ to_remove.append(model)
+ for m in to_remove:
+ available_models.remove(m)
+
+
  for rm in repository_models:
  if rm not in available_models:
  available_models.append(rm)

+
  if not return_path:
  return available_models
  else:
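The cleanup loop added to get_segmentation_models_list treats config_input.json as the marker of a complete model download and removes any local model folder that lacks it before merging in the repository models. A hedged sketch of the same check in isolation, using os.path.isfile rather than comparing glob results; the model directory path is illustrative:

import os
from glob import glob
from shutil import rmtree
from natsort import natsorted

modelpath = os.path.expanduser("~/celldetective_models/")   # illustrative location

available_models = [os.path.basename(os.path.normpath(m))
                    for m in natsorted(glob(modelpath + '*/'))]

for model in list(available_models):
    config = os.path.join(modelpath, model, "config_input.json")
    if not os.path.isfile(config):
        # Incomplete or corrupted model folder: delete it and drop it from the list.
        rmtree(os.path.join(modelpath, model))
        available_models.remove(model)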