celldetective 1.3.9.post4__py3-none-any.whl → 1.4.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (57)
  1. celldetective/__init__.py +0 -3
  2. celldetective/_version.py +1 -1
  3. celldetective/events.py +2 -4
  4. celldetective/extra_properties.py +320 -24
  5. celldetective/gui/InitWindow.py +33 -45
  6. celldetective/gui/__init__.py +1 -0
  7. celldetective/gui/about.py +19 -15
  8. celldetective/gui/analyze_block.py +34 -19
  9. celldetective/gui/base_components.py +23 -0
  10. celldetective/gui/btrack_options.py +26 -34
  11. celldetective/gui/classifier_widget.py +71 -80
  12. celldetective/gui/configure_new_exp.py +113 -17
  13. celldetective/gui/control_panel.py +68 -141
  14. celldetective/gui/generic_signal_plot.py +9 -12
  15. celldetective/gui/gui_utils.py +49 -21
  16. celldetective/gui/json_readers.py +5 -4
  17. celldetective/gui/layouts.py +246 -22
  18. celldetective/gui/measurement_options.py +32 -17
  19. celldetective/gui/neighborhood_options.py +10 -13
  20. celldetective/gui/plot_measurements.py +21 -17
  21. celldetective/gui/plot_signals_ui.py +131 -75
  22. celldetective/gui/process_block.py +180 -123
  23. celldetective/gui/processes/compute_neighborhood.py +594 -0
  24. celldetective/gui/processes/measure_cells.py +5 -0
  25. celldetective/gui/processes/segment_cells.py +27 -6
  26. celldetective/gui/processes/track_cells.py +6 -0
  27. celldetective/gui/retrain_segmentation_model_options.py +12 -20
  28. celldetective/gui/retrain_signal_model_options.py +57 -56
  29. celldetective/gui/seg_model_loader.py +21 -62
  30. celldetective/gui/signal_annotator.py +139 -72
  31. celldetective/gui/signal_annotator2.py +431 -635
  32. celldetective/gui/signal_annotator_options.py +8 -11
  33. celldetective/gui/survival_ui.py +49 -95
  34. celldetective/gui/tableUI.py +28 -25
  35. celldetective/gui/thresholds_gui.py +617 -1221
  36. celldetective/gui/viewers.py +106 -39
  37. celldetective/gui/workers.py +9 -3
  38. celldetective/io.py +73 -27
  39. celldetective/measure.py +63 -27
  40. celldetective/neighborhood.py +342 -268
  41. celldetective/preprocessing.py +25 -17
  42. celldetective/relative_measurements.py +50 -29
  43. celldetective/scripts/analyze_signals.py +4 -1
  44. celldetective/scripts/measure_relative.py +4 -1
  45. celldetective/scripts/segment_cells.py +0 -6
  46. celldetective/scripts/track_cells.py +3 -1
  47. celldetective/scripts/train_segmentation_model.py +7 -4
  48. celldetective/signals.py +29 -14
  49. celldetective/tracking.py +7 -2
  50. celldetective/utils.py +36 -8
  51. {celldetective-1.3.9.post4.dist-info → celldetective-1.4.0.dist-info}/METADATA +24 -16
  52. {celldetective-1.3.9.post4.dist-info → celldetective-1.4.0.dist-info}/RECORD +57 -55
  53. {celldetective-1.3.9.post4.dist-info → celldetective-1.4.0.dist-info}/WHEEL +1 -1
  54. tests/test_qt.py +21 -21
  55. {celldetective-1.3.9.post4.dist-info → celldetective-1.4.0.dist-info}/entry_points.txt +0 -0
  56. {celldetective-1.3.9.post4.dist-info → celldetective-1.4.0.dist-info/licenses}/LICENSE +0 -0
  57. {celldetective-1.3.9.post4.dist-info → celldetective-1.4.0.dist-info}/top_level.txt +0 -0
@@ -3,7 +3,7 @@ from celldetective.io import auto_load_number_of_frames, load_frames
3
3
  from celldetective.filters import *
4
4
  from celldetective.segmentation import filter_image, threshold_image
5
5
  from celldetective.measure import contour_of_instance_segmentation, extract_blobs_in_image
6
- from celldetective.utils import _get_img_num_per_channel, estimate_unreliable_edge
6
+ from celldetective.utils import _get_img_num_per_channel, estimate_unreliable_edge, is_integer_array
7
7
  from tifffile import imread
8
8
  import matplotlib.pyplot as plt
9
9
  from pathlib import Path
@@ -11,11 +11,11 @@ from natsort import natsorted
11
11
  from glob import glob
12
12
  import os
13
13
 
14
- from PyQt5.QtWidgets import QWidget, QHBoxLayout, QPushButton, QLabel, QComboBox, QLineEdit, QListWidget, QShortcut
14
+ from PyQt5.QtWidgets import QHBoxLayout, QPushButton, QLabel, QComboBox, QLineEdit, QListWidget, QShortcut
15
15
  from PyQt5.QtCore import Qt, QSize
16
16
  from PyQt5.QtGui import QKeySequence, QDoubleValidator
17
17
  from celldetective.gui.gui_utils import FigureCanvas, center_window, QuickSliderLayout, QHSeperationLine, ThresholdLineEdit, PreprocessingLayout2
18
- from celldetective.gui import Styles
18
+ from celldetective.gui import Styles, CelldetectiveWidget
19
19
  from superqt import QLabeledDoubleSlider, QLabeledSlider, QLabeledDoubleRangeSlider
20
20
  from superqt.fonticon import icon
21
21
  from fonticon_mdi6 import MDI6
@@ -24,7 +24,7 @@ import gc
24
24
  from celldetective.utils import mask_edges
25
25
  from scipy.ndimage import shift
26
26
 
27
- class StackVisualizer(QWidget, Styles):
27
+ class StackVisualizer(CelldetectiveWidget):
28
28
 
29
29
  """
30
30
  A widget for visualizing image stacks with interactive sliders and channel selection.
@@ -92,8 +92,7 @@ class StackVisualizer(QWidget, Styles):
92
92
  self.generate_frame_slider()
93
93
 
94
94
  self.canvas.layout.setContentsMargins(15,15,15,30)
95
- self.setAttribute(Qt.WA_DeleteOnClose)
96
- center_window(self)
95
+ #center_window(self)
97
96
 
98
97
  def show(self):
99
98
  # Display the widget
@@ -136,6 +135,7 @@ class StackVisualizer(QWidget, Styles):
136
135
  self.stack_path,
137
136
  normalize_input=False).astype(float)[:,:,0]
138
137
 
138
+
139
139
  def generate_figure_canvas(self):
140
140
  # Generate the figure canvas for displaying images
141
141
 
@@ -187,6 +187,16 @@ class StackVisualizer(QWidget, Styles):
187
187
  channel_layout.addWidget(self.channels_cb, 75)
188
188
  self.canvas.layout.addLayout(channel_layout)
189
189
 
190
+ def set_contrast_decimals(self):
191
+ if is_integer_array(self.init_frame):
192
+ self.contrast_slider.setDecimals(0)
193
+ self.contrast_slider.setSingleStep(1.0)
194
+ self.contrast_slider.setTickInterval(1.0)
195
+ else:
196
+ self.contrast_slider.setDecimals(3)
197
+ self.contrast_slider.setSingleStep(1.0E-03)
198
+ self.contrast_slider.setTickInterval(1.0E-03)
199
+
190
200
  def generate_contrast_slider(self):
191
201
  # Generate the contrast slider if enabled
192
202
 
@@ -197,15 +207,15 @@ class StackVisualizer(QWidget, Styles):
197
207
  slider_initial_value=[np.nanpercentile(self.init_frame, 0.1),np.nanpercentile(self.init_frame, 99.99)],
198
208
  slider_range=(np.nanmin(self.init_frame),np.nanmax(self.init_frame)),
199
209
  decimal_option=True,
200
- precision=1.0E-05,
210
+ precision=2,
201
211
  )
212
+ self.set_contrast_decimals()
213
+
202
214
  contrast_layout.setContentsMargins(15,0,15,0)
203
215
  self.im.set_clim(vmin=np.nanpercentile(self.init_frame, 0.1),vmax=np.nanpercentile(self.init_frame, 99.99))
204
216
  self.contrast_slider.valueChanged.connect(self.change_contrast)
205
217
  self.canvas.layout.addLayout(contrast_layout)
206
218
 
207
-
208
-
209
219
  def generate_frame_slider(self):
210
220
  # Generate the frame slider if enabled
211
221
 
@@ -250,6 +260,8 @@ class StackVisualizer(QWidget, Styles):
250
260
  self.channel_trigger = False
251
261
  self.init_contrast = False
252
262
 
263
+ self.set_contrast_decimals()
264
+
253
265
  def change_frame_from_channel_switch(self, value):
254
266
 
255
267
  self.channel_trigger = True
@@ -275,15 +287,17 @@ class StackVisualizer(QWidget, Styles):
275
287
  self.im.set_data(self.init_frame)
276
288
 
277
289
  if self.init_contrast:
278
- self.im.autoscale()
279
- I_min, I_max = self.im.get_clim()
280
- self.contrast_slider.setRange(np.nanmin([self.init_frame,self.last_frame]),np.nanmax([self.init_frame,self.last_frame]))
281
- self.contrast_slider.setValue((I_min,I_max))
290
+ imgs = np.array([self.init_frame,self.last_frame])
291
+ vmin = np.nanpercentile(imgs.flatten(), 1.0)
292
+ vmax = np.nanpercentile(imgs.flatten(), 99.99)
293
+ self.contrast_slider.setRange(np.nanmin(imgs),np.nanmax(imgs))
294
+ self.contrast_slider.setValue((vmin,vmax))
295
+ self.im.set_clim(vmin,vmax)
282
296
 
283
297
  if self.create_contrast_slider:
284
298
  self.change_contrast(self.contrast_slider.value())
285
299
 
286
-
300
+
287
301
  def closeEvent(self, event):
288
302
  # Event handler for closing the widget
289
303
  self.canvas.close()
@@ -318,16 +332,27 @@ class ThresholdedStackVisualizer(StackVisualizer):
318
332
  with interactive sliders for threshold and mask opacity adjustment.
319
333
  """
320
334
 
321
- def __init__(self, preprocessing=None, parent_le=None, initial_threshold=5, initial_mask_alpha=0.5, *args, **kwargs):
335
+ def __init__(self, preprocessing=None, parent_le=None, initial_threshold=5, initial_mask_alpha=0.5, show_opacity_slider=True, show_threshold_slider=True, *args, **kwargs):
322
336
  # Initialize the widget and its attributes
323
337
  super().__init__(*args, **kwargs)
324
338
  self.preprocessing = preprocessing
325
339
  self.thresh = initial_threshold
326
340
  self.mask_alpha = initial_mask_alpha
327
341
  self.parent_le = parent_le
328
- self.compute_mask(self.thresh)
329
- self.generate_mask_imshow()
342
+ self.show_opacity_slider = show_opacity_slider
343
+ self.show_threshold_slider = show_threshold_slider
344
+ self.thresholded = False
345
+ self.mask = np.zeros_like(self.init_frame)
346
+ self.thresh_min = 0.0
347
+ self.thresh_max = 30.0
348
+
330
349
  self.generate_threshold_slider()
350
+
351
+ if self.thresh is not None:
352
+ self.compute_mask(self.thresh)
353
+
354
+ self.generate_mask_imshow()
355
+ self.generate_scatter()
331
356
  self.generate_opacity_slider()
332
357
  if isinstance(self.parent_le, QLineEdit):
333
358
  self.generate_apply_btn()
@@ -349,23 +374,32 @@ class ThresholdedStackVisualizer(StackVisualizer):
349
374
  self.close()
350
375
 
351
376
  def generate_mask_imshow(self):
352
- # Generate the mask imshow
377
+ # Generate the mask imshow
378
+
353
379
  self.im_mask = self.ax.imshow(np.ma.masked_where(self.mask==0, self.mask), alpha=self.mask_alpha, interpolation='none')
354
380
  self.canvas.canvas.draw()
355
381
 
382
+ def generate_scatter(self):
383
+ self.scat_markers = self.ax.scatter([], [], color="tab:red")
384
+
356
385
  def generate_threshold_slider(self):
357
386
  # Generate the threshold slider
358
387
  self.threshold_slider = QLabeledDoubleSlider()
388
+ if self.thresh is None:
389
+ init_value = 1.0E5
390
+ else:
391
+ init_value = self.thresh
359
392
  thresh_layout = QuickSliderLayout(label='Threshold: ',
360
393
  slider=self.threshold_slider,
361
- slider_initial_value=self.thresh,
362
- slider_range=(0,30),
394
+ slider_initial_value=init_value,
395
+ slider_range=(self.thresh_min,np.amax([self.thresh_max, init_value])),
363
396
  decimal_option=True,
364
- precision=1.0E-05,
397
+ precision=4,
365
398
  )
366
399
  thresh_layout.setContentsMargins(15,0,15,0)
367
400
  self.threshold_slider.valueChanged.connect(self.change_threshold)
368
- self.canvas.layout.addLayout(thresh_layout)
401
+ if self.show_threshold_slider:
402
+ self.canvas.layout.addLayout(thresh_layout)
369
403
 
370
404
  def generate_opacity_slider(self):
371
405
  # Generate the opacity slider for the mask
@@ -375,11 +409,12 @@ class ThresholdedStackVisualizer(StackVisualizer):
375
409
  slider_initial_value=0.5,
376
410
  slider_range=(0,1),
377
411
  decimal_option=True,
378
- precision=1.0E-03
412
+ precision=3,
379
413
  )
380
414
  opacity_layout.setContentsMargins(15,0,15,0)
381
415
  self.opacity_slider.valueChanged.connect(self.change_mask_opacity)
382
- self.canvas.layout.addLayout(opacity_layout)
416
+ if self.show_opacity_slider:
417
+ self.canvas.layout.addLayout(opacity_layout)
383
418
 
384
419
  def change_mask_opacity(self, value):
385
420
  # Change the opacity of the mask
@@ -390,28 +425,61 @@ class ThresholdedStackVisualizer(StackVisualizer):
390
425
  def change_threshold(self, value):
391
426
  # Change the threshold value
392
427
  self.thresh = value
393
- self.compute_mask(self.thresh)
394
- mask = np.ma.masked_where(self.mask == 0, self.mask)
395
- self.im_mask.set_data(mask)
396
- self.canvas.canvas.draw_idle()
428
+ if self.thresh is not None:
429
+ self.compute_mask(self.thresh)
430
+ mask = np.ma.masked_where(self.mask == 0, self.mask)
431
+ self.im_mask.set_data(mask)
432
+ self.canvas.canvas.draw_idle()
397
433
 
398
434
  def change_frame(self, value):
399
- # Change the displayed frame and update the threshold
435
+ # Change the displayed frame and update the threshold
436
+ if self.thresholded:
437
+ self.init_contrast = True
400
438
  super().change_frame(value)
401
439
  self.change_threshold(self.threshold_slider.value())
440
+ if self.thresholded:
441
+ self.thresholded = False
442
+ self.init_contrast = False
402
443
 
403
444
  def compute_mask(self, threshold_value):
404
445
  # Compute the mask based on the threshold value
405
446
  self.preprocess_image()
406
447
  edge = estimate_unreliable_edge(self.preprocessing)
407
- self.mask = threshold_image(self.processed_image, threshold_value, np.inf, foreground_value=1, edge_exclusion=edge).astype(int)
448
+ if isinstance(threshold_value, (list,np.ndarray,tuple)):
449
+ self.mask = threshold_image(self.processed_image, threshold_value[0], threshold_value[1], foreground_value=1, fill_holes=True, edge_exclusion=edge).astype(int)
450
+ else:
451
+ self.mask = threshold_image(self.processed_image, threshold_value, np.inf, foreground_value=1, fill_holes=True, edge_exclusion=edge).astype(int)
408
452
 
409
453
  def preprocess_image(self):
410
454
  # Preprocess the image before thresholding
411
455
  if self.preprocessing is not None:
412
456
 
413
457
  assert isinstance(self.preprocessing, list)
414
- self.processed_image = filter_image(self.init_frame.copy(),filters=self.preprocessing)
458
+ self.processed_image = filter_image(self.init_frame.copy().astype(float),filters=self.preprocessing)
459
+ min_ = np.amin(self.processed_image)
460
+ max_ = np.amax(self.processed_image)
461
+
462
+ if min_ < self.thresh_min:
463
+ self.thresh_min = min_
464
+ if max_ > self.thresh_max:
465
+ self.thresh_max = max_
466
+
467
+ self.threshold_slider.setRange(self.thresh_min, self.thresh_max)
468
+
469
+ def set_preprocessing(self, activation_protocol):
470
+
471
+ self.preprocessing = activation_protocol
472
+ self.preprocess_image()
473
+
474
+ self.im.set_data(self.processed_image)
475
+ vmin = np.nanpercentile(self.processed_image, 1.0)
476
+ vmax = np.nanpercentile(self.processed_image, 99.99)
477
+ self.contrast_slider.setRange(np.nanmin(self.processed_image),
478
+ np.nanmax(self.processed_image))
479
+ self.contrast_slider.setValue((vmin, vmax))
480
+ self.im.set_clim(vmin,vmax)
481
+ self.canvas.canvas.draw_idle()
482
+ self.thresholded = True
415
483
 
416
484
 
417
485
  class CellEdgeVisualizer(StackVisualizer):
@@ -578,7 +646,7 @@ class CellEdgeVisualizer(StackVisualizer):
578
646
  slider_initial_value=0.5,
579
647
  slider_range=(0,1),
580
648
  decimal_option=True,
581
- precision=1.0E-03
649
+ precision=3,
582
650
  )
583
651
  opacity_layout.setContentsMargins(15,0,15,0)
584
652
  self.opacity_slider.valueChanged.connect(self.change_mask_opacity)
@@ -952,7 +1020,7 @@ class CellSizeViewer(StackVisualizer):
952
1020
  def generate_circle(self):
953
1021
  # Generate the circle for visualization
954
1022
 
955
- self.circ = plt.Circle((self.init_frame.shape[0]//2,self.init_frame.shape[1]//2), self.diameter//2, ec="tab:red",fill=False)
1023
+ self.circ = plt.Circle((self.init_frame.shape[0]//2,self.init_frame.shape[1]//2), self.diameter//2 / self.PxToUm, ec="tab:red",fill=False)
956
1024
  self.ax.add_patch(self.circ)
957
1025
 
958
1026
  self.ax.callbacks.connect('xlim_changed',self.on_xlims_or_ylims_change)
@@ -978,7 +1046,7 @@ class CellSizeViewer(StackVisualizer):
978
1046
  if self.set_radius_in_list:
979
1047
  val = int(self.diameter_slider.value()//2)
980
1048
  else:
981
- val = int(self.diameter_slider.value())
1049
+ val = int(self.diameter_slider.value())
982
1050
 
983
1051
  self.parent_list_widget.addItems([str(val)])
984
1052
  self.close()
@@ -1017,7 +1085,7 @@ class CellSizeViewer(StackVisualizer):
1017
1085
  slider_initial_value=self.diameter,
1018
1086
  slider_range=self.diameter_slider_range,
1019
1087
  decimal_option=True,
1020
- precision=1.0E-05,
1088
+ precision=5,
1021
1089
  )
1022
1090
  diameter_layout.setContentsMargins(15,0,15,0)
1023
1091
  self.diameter_slider.valueChanged.connect(self.change_diameter)
@@ -1025,9 +1093,8 @@ class CellSizeViewer(StackVisualizer):
1025
1093
 
1026
1094
  def change_diameter(self, value):
1027
1095
  # Change the diameter of the circle
1028
-
1029
1096
  self.diameter = value
1030
- self.circ.set_radius(self.diameter//2)
1097
+ self.circ.set_radius(self.diameter//2 / self.PxToUm)
1031
1098
  self.canvas.canvas.draw_idle()
1032
1099
 
1033
1100
 
@@ -1080,7 +1147,7 @@ class ChannelOffsetViewer(StackVisualizer):
1080
1147
  slider_initial_value=0.5,
1081
1148
  slider_range=(0,1.0),
1082
1149
  decimal_option=True,
1083
- precision=1.0E-05,
1150
+ precision=5,
1084
1151
  )
1085
1152
  alpha_layout.setContentsMargins(15,0,15,0)
1086
1153
  self.overlay_alpha_slider.valueChanged.connect(self.change_alpha_overlay)
@@ -1097,7 +1164,7 @@ class ChannelOffsetViewer(StackVisualizer):
1097
1164
  slider_initial_value=[np.nanpercentile(self.overlay_init_frame, 0.1),np.nanpercentile(self.overlay_init_frame, 99.99)],
1098
1165
  slider_range=(np.nanmin(self.overlay_init_frame),np.nanmax(self.overlay_init_frame)),
1099
1166
  decimal_option=True,
1100
- precision=1.0E-05,
1167
+ precision=5,
1101
1168
  )
1102
1169
  contrast_layout.setContentsMargins(15,0,15,0)
1103
1170
  self.im_overlay.set_clim(vmin=np.nanpercentile(self.overlay_init_frame, 0.1),vmax=np.nanpercentile(self.overlay_init_frame, 99.99))
@@ -1,18 +1,24 @@
1
1
  from multiprocessing import Queue
2
- from PyQt5.QtWidgets import QDialog, QPushButton, QVBoxLayout, QHBoxLayout, QWidget, QLabel, QProgressBar
2
+ from PyQt5.QtWidgets import QPushButton, QVBoxLayout, QHBoxLayout, QLabel, QProgressBar
3
3
  from PyQt5.QtCore import QRunnable, QObject, pyqtSignal, QThreadPool, QSize, Qt
4
+
5
+ from celldetective.gui.base_components import CelldetectiveDialog
4
6
  from celldetective.gui.gui_utils import center_window
7
+ from celldetective.gui import Styles
5
8
  import time
6
9
  import math
7
10
 
8
- class ProgressWindow(QDialog):
11
+ class ProgressWindow(CelldetectiveDialog):
9
12
 
10
13
  def __init__(self, process=None, parent_window=None, title="", position_info=True, process_args=None):
11
- QDialog.__init__(self)
14
+
15
+ super().__init__()
16
+ #QDialog.__init__(self)
12
17
 
13
18
  self.setWindowTitle(f'{title} Progress')
14
19
  self.__process = process
15
20
  self.parent_window = parent_window
21
+
16
22
  self.position_info = position_info
17
23
  if self.position_info:
18
24
  self.pos_name = self.parent_window.pos_name
celldetective/io.py CHANGED
@@ -29,7 +29,7 @@ from celldetective.utils import interpolate_nan_multichannel, _estimate_scale_fa
29
29
 
30
30
  from stardist import fill_label_holes
31
31
  from skimage.transform import resize
32
-
32
+ import re
33
33
 
34
34
  def extract_experiment_from_well(well_path):
35
35
 
@@ -596,6 +596,17 @@ def get_experiment_pharmaceutical_agents(experiment, dtype=str):
596
596
  return np.array([dtype(c) for c in pharmaceutical_agents])
597
597
 
598
598
 
599
+ def get_experiment_populations(experiment, dtype=str):
600
+
601
+ config = get_config(experiment)
602
+ populations_str = ConfigSectionMap(config, "Populations")
603
+ if populations_str is not None:
604
+ populations = populations_str['populations'].split(',')
605
+ else:
606
+ populations = ['effectors','targets']
607
+ return list([dtype(c) for c in populations])
608
+
609
+
599
610
  def interpret_wells_and_positions(experiment, well_option, position_option):
600
611
  """
601
612
  Interpret well and position options for a given experiment.
@@ -1165,6 +1176,9 @@ def locate_labels(position, population='target', frames=None):
1165
1176
  label_path = natsorted(glob(position + os.sep.join(["labels_targets", "*.tif"])))
1166
1177
  elif population.lower() == "effector" or population.lower() == "effectors":
1167
1178
  label_path = natsorted(glob(position + os.sep.join(["labels_effectors", "*.tif"])))
1179
+ else:
1180
+ label_path = natsorted(glob(position + os.sep.join([f"labels_{population}", "*.tif"])))
1181
+
1168
1182
 
1169
1183
  label_names = [os.path.split(lbl)[-1] for lbl in label_path]
1170
1184
 
@@ -1242,6 +1256,9 @@ def fix_missing_labels(position, population='target', prefix='Aligned'):
1242
1256
  elif population.lower() == "effector" or population.lower() == "effectors":
1243
1257
  label_path = natsorted(glob(position + os.sep.join(["labels_effectors", "*.tif"])))
1244
1258
  path = position + os.sep + "labels_effectors"
1259
+ else:
1260
+ label_path = natsorted(glob(position + os.sep.join([f"labels_{population}", "*.tif"])))
1261
+ path = position + os.sep + f"labels_{population}"
1245
1262
 
1246
1263
  if label_path!=[]:
1247
1264
  #path = os.path.split(label_path[0])[0]
@@ -1348,6 +1365,9 @@ def load_tracking_data(position, prefix="Aligned", population="target"):
1348
1365
  trajectories = pd.read_csv(position + os.sep.join(['output', 'tables', 'trajectories_targets.csv']))
1349
1366
  elif population.lower() == "effector" or population.lower() == "effectors":
1350
1367
  trajectories = pd.read_csv(position + os.sep.join(['output', 'tables', 'trajectories_effectors.csv']))
1368
+ else:
1369
+ trajectories = pd.read_csv(position + os.sep.join(['output', 'tables', f'trajectories_{population}.csv']))
1370
+
1351
1371
 
1352
1372
  stack, labels = locate_stack_and_labels(position, prefix=prefix, population=population)
1353
1373
 
@@ -1941,10 +1961,10 @@ def relabel_segmentation(labels, df, exclude_nans=True, column_labels={'track':
1941
1961
 
1942
1962
  def rewrite_labels(indices):
1943
1963
 
1944
- all_track_ids = df[column_labels['track']].unique()
1964
+ all_track_ids = df[column_labels['track']].dropna().unique()
1945
1965
 
1946
1966
  for t in tqdm(indices):
1947
-
1967
+
1948
1968
  f = int(t)
1949
1969
  cells = df.loc[df[column_labels['frame']] == f, [column_labels['track'], column_labels['label']]].to_numpy()
1950
1970
  tracks_at_t = list(cells[:,0])
@@ -1974,15 +1994,23 @@ def relabel_segmentation(labels, df, exclude_nans=True, column_labels={'track':
1974
1994
 
1975
1995
  loc_i, loc_j = np.where(labels[f] == identities[k])
1976
1996
  track_id = tracks_at_t[k]
1977
- new_labels[f, loc_i, loc_j] = round(track_id)
1997
+
1998
+ if track_id==track_id:
1999
+ new_labels[f, loc_i, loc_j] = round(track_id)
1978
2000
 
1979
2001
  # Multithreading
1980
- indices = list(df[column_labels['frame']].unique())
2002
+ indices = list(df[column_labels['frame']].dropna().unique())
1981
2003
  chunks = np.array_split(indices, n_threads)
1982
2004
 
1983
- with concurrent.futures.ThreadPoolExecutor() as executor:
1984
- executor.map(rewrite_labels, chunks)
1985
-
2005
+ with concurrent.futures.ThreadPoolExecutor(max_workers=threads) as executor:
2006
+
2007
+ results = executor.map(rewrite_labels, chunks) #list(map(lambda x: executor.submit(self.parallel_job, x), chunks))
2008
+ try:
2009
+ for i,return_value in enumerate(results):
2010
+ print(f"Thread {i} output check: ",return_value)
2011
+ except Exception as e:
2012
+ print("Exception: ", e)
2013
+
1986
2014
  print("\nDone.")
1987
2015
 
1988
2016
  return new_labels
@@ -2088,6 +2116,7 @@ def tracks_to_btrack(df, exclude_nans=False):
2088
2116
  graph = {}
2089
2117
  if exclude_nans:
2090
2118
  df.dropna(subset='class_id',inplace=True)
2119
+ df.dropna(subset='TRACK_ID',inplace=True)
2091
2120
 
2092
2121
  df["z"] = 0.
2093
2122
  data = df[["TRACK_ID","FRAME","z","POSITION_Y","POSITION_X"]].to_numpy()
@@ -2345,6 +2374,11 @@ def load_napari_data(position, prefix="Aligned", population="target", return_sta
2345
2374
  napari_data = np.load(position+os.sep.join(['output', 'tables', 'napari_effector_trajectories.npy']), allow_pickle=True)
2346
2375
  else:
2347
2376
  napari_data = None
2377
+ else:
2378
+ if os.path.exists(position+os.sep.join(['output', 'tables', f'napari_{population}_trajectories.npy'])):
2379
+ napari_data = np.load(position+os.sep.join(['output', 'tables', f'napari_{population}_trajectories.npy']), allow_pickle=True)
2380
+ else:
2381
+ napari_data = None
2348
2382
 
2349
2383
  if napari_data is not None:
2350
2384
  data = napari_data.item()['data']
@@ -2484,6 +2518,9 @@ def control_segmentation_napari(position, prefix='Aligned', population="target",
2484
2518
 
2485
2519
  def export_labels():
2486
2520
  labels_layer = viewer.layers['segmentation'].data
2521
+ if not os.path.exists(output_folder):
2522
+ os.mkdir(output_folder)
2523
+
2487
2524
  for t, im in enumerate(tqdm(labels_layer)):
2488
2525
 
2489
2526
  try:
@@ -2630,12 +2667,9 @@ def control_segmentation_napari(position, prefix='Aligned', population="target",
2630
2667
  return export_annotation()
2631
2668
 
2632
2669
  stack, labels = locate_stack_and_labels(position, prefix=prefix, population=population)
2633
-
2634
- if not population.endswith('s'):
2635
- population += 's'
2636
2670
  output_folder = position + f'labels_{population}{os.sep}'
2671
+ print(f"Shape of the loaded image stack: {stack.shape}...")
2637
2672
 
2638
- print(f"{stack.shape}")
2639
2673
  viewer = napari.Viewer()
2640
2674
  viewer.add_image(stack, channel_axis=-1, colormap=["gray"] * stack.shape[-1])
2641
2675
  viewer.add_labels(labels.astype(int), name='segmentation', opacity=0.4)
@@ -2669,6 +2703,8 @@ def control_segmentation_napari(position, prefix='Aligned', population="target",
2669
2703
  del labels
2670
2704
  gc.collect()
2671
2705
 
2706
+ print("napari viewer was successfully closed...")
2707
+
2672
2708
  def correct_annotation(filename):
2673
2709
 
2674
2710
  """
@@ -2818,21 +2854,31 @@ def control_tracking_table(position, calibration=1, prefix="Aligned", population
2818
2854
 
2819
2855
 
2820
2856
  def get_segmentation_models_list(mode='targets', return_path=False):
2821
- if mode == 'targets':
2822
- modelpath = os.sep.join(
2823
- [os.path.split(os.path.dirname(os.path.realpath(__file__)))[0], "celldetective", "models",
2824
- "segmentation_targets", os.sep])
2825
- repository_models = get_zenodo_files(cat=os.sep.join(["models", "segmentation_targets"]))
2826
- elif mode == 'effectors':
2827
- modelpath = os.sep.join(
2828
- [os.path.split(os.path.dirname(os.path.realpath(__file__)))[0], "celldetective", "models",
2829
- "segmentation_effectors", os.sep])
2830
- repository_models = get_zenodo_files(cat=os.sep.join(["models", "segmentation_effectors"]))
2831
- elif mode == 'generic':
2832
- modelpath = os.sep.join(
2857
+
2858
+ modelpath = os.sep.join(
2833
2859
  [os.path.split(os.path.dirname(os.path.realpath(__file__)))[0], "celldetective", "models",
2834
- "segmentation_generic", os.sep])
2835
- repository_models = get_zenodo_files(cat=os.sep.join(["models", "segmentation_generic"]))
2860
+ f"segmentation_{mode}", os.sep])
2861
+ if not os.path.exists(modelpath):
2862
+ os.mkdir(modelpath)
2863
+ repository_models = []
2864
+ else:
2865
+ repository_models = get_zenodo_files(cat=os.sep.join(["models", f"segmentation_{mode}"]))
2866
+
2867
+ # if mode == 'targets':
2868
+ # modelpath = os.sep.join(
2869
+ # [os.path.split(os.path.dirname(os.path.realpath(__file__)))[0], "celldetective", "models",
2870
+ # "segmentation_targets", os.sep])
2871
+ # repository_models = get_zenodo_files(cat=os.sep.join(["models", "segmentation_targets"]))
2872
+ # elif mode == 'effectors':
2873
+ # modelpath = os.sep.join(
2874
+ # [os.path.split(os.path.dirname(os.path.realpath(__file__)))[0], "celldetective", "models",
2875
+ # "segmentation_effectors", os.sep])
2876
+ # repository_models = get_zenodo_files(cat=os.sep.join(["models", "segmentation_effectors"]))
2877
+ # elif mode == 'generic':
2878
+ # modelpath = os.sep.join(
2879
+ # [os.path.split(os.path.dirname(os.path.realpath(__file__)))[0], "celldetective", "models",
2880
+ # "segmentation_generic", os.sep])
2881
+ # repository_models = get_zenodo_files(cat=os.sep.join(["models", "segmentation_generic"]))
2836
2882
 
2837
2883
  available_models = natsorted(glob(modelpath + '*/'))
2838
2884
  available_models = [m.replace('\\', '/').split('/')[-2] for m in available_models]
@@ -3266,7 +3312,7 @@ def normalize_multichannel(multichannel_frame, percentiles=None,
3266
3312
 
3267
3313
  return np.moveaxis(mf_new,0,-1)
3268
3314
 
3269
- def load_frames(img_nums, stack_path, scale=None, normalize_input=True, dtype=float, normalize_kwargs={"percentiles": (0.,99.99)}):
3315
+ def load_frames(img_nums, stack_path, scale=None, normalize_input=True, dtype=np.float64, normalize_kwargs={"percentiles": (0.,99.99)}):
3270
3316
 
3271
3317
  """
3272
3318
  Loads and optionally normalizes and rescales specified frames from a stack located at a given path.