celldetective 1.3.6.post1__py3-none-any.whl → 1.3.7__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (41)
  1. celldetective/_version.py +1 -1
  2. celldetective/events.py +4 -0
  3. celldetective/filters.py +11 -2
  4. celldetective/gui/InitWindow.py +23 -9
  5. celldetective/gui/control_panel.py +19 -11
  6. celldetective/gui/generic_signal_plot.py +5 -0
  7. celldetective/gui/gui_utils.py +2 -2
  8. celldetective/gui/help/DL-segmentation-strategy.json +17 -17
  9. celldetective/gui/help/Threshold-vs-DL.json +11 -11
  10. celldetective/gui/help/cell-populations.json +5 -5
  11. celldetective/gui/help/exp-structure.json +15 -15
  12. celldetective/gui/help/feature-btrack.json +5 -5
  13. celldetective/gui/help/neighborhood.json +7 -7
  14. celldetective/gui/help/prefilter-for-segmentation.json +7 -7
  15. celldetective/gui/help/preprocessing.json +19 -19
  16. celldetective/gui/help/propagate-classification.json +7 -7
  17. celldetective/gui/neighborhood_options.py +1 -1
  18. celldetective/gui/plot_signals_ui.py +13 -9
  19. celldetective/gui/process_block.py +63 -14
  20. celldetective/gui/retrain_segmentation_model_options.py +21 -8
  21. celldetective/gui/retrain_signal_model_options.py +12 -2
  22. celldetective/gui/signal_annotator.py +9 -0
  23. celldetective/gui/signal_annotator2.py +25 -17
  24. celldetective/gui/styles.py +1 -0
  25. celldetective/gui/tableUI.py +1 -1
  26. celldetective/gui/workers.py +136 -0
  27. celldetective/io.py +54 -28
  28. celldetective/measure.py +112 -14
  29. celldetective/scripts/measure_cells.py +36 -46
  30. celldetective/scripts/segment_cells.py +35 -78
  31. celldetective/scripts/segment_cells_thresholds.py +21 -22
  32. celldetective/scripts/track_cells.py +43 -32
  33. celldetective/segmentation.py +16 -62
  34. celldetective/signals.py +11 -7
  35. celldetective/utils.py +587 -67
  36. {celldetective-1.3.6.post1.dist-info → celldetective-1.3.7.dist-info}/METADATA +1 -1
  37. {celldetective-1.3.6.post1.dist-info → celldetective-1.3.7.dist-info}/RECORD +41 -40
  38. {celldetective-1.3.6.post1.dist-info → celldetective-1.3.7.dist-info}/LICENSE +0 -0
  39. {celldetective-1.3.6.post1.dist-info → celldetective-1.3.7.dist-info}/WHEEL +0 -0
  40. {celldetective-1.3.6.post1.dist-info → celldetective-1.3.7.dist-info}/entry_points.txt +0 -0
  41. {celldetective-1.3.6.post1.dist-info → celldetective-1.3.7.dist-info}/top_level.txt +0 -0

celldetective/scripts/segment_cells.py

@@ -6,21 +6,17 @@ import argparse
  import datetime
  import os
  import json
- from stardist.models import StarDist2D
- from cellpose.models import CellposeModel
- from celldetective.io import locate_segmentation_model, auto_load_number_of_frames, load_frames
- from celldetective.utils import interpolate_nan, _estimate_scale_factor, _extract_channel_indices_from_config, ConfigSectionMap, _extract_nbr_channels_from_config, _get_img_num_per_channel
+ from celldetective.io import locate_segmentation_model, auto_load_number_of_frames, extract_position_name, _load_frames_to_segment, _check_label_dims
+ from celldetective.utils import _prep_stardist_model, _prep_cellpose_model, _rescale_labels, _segment_image_with_stardist_model,_segment_image_with_cellpose_model,_get_normalize_kwargs_from_config, _estimate_scale_factor, _extract_channel_indices_from_config, ConfigSectionMap, _extract_nbr_channels_from_config, _get_img_num_per_channel
  from pathlib import Path, PurePath
  from glob import glob
  from shutil import rmtree
  from tqdm import tqdm
  import numpy as np
- from skimage.transform import resize
  from csbdeep.io import save_tiff_imagej_compatible
  import gc
  from art import tprint
- from scipy.ndimage import zoom
-
+ import concurrent.futures

  tprint("Segment")

@@ -44,6 +40,7 @@ if use_gpu=='True' or use_gpu=='true' or use_gpu=='1':
  n_threads = 1 # avoid misbehavior on GPU with multithreading
  else:
  use_gpu = False
+ #n_threads = 1 # force 1 threads since all CPUs seem to be in use anyway

  if not use_gpu:
  os.environ['CUDA_VISIBLE_DEVICES'] = '-1'
@@ -60,20 +57,22 @@ parent1 = Path(pos).parent
  expfolder = parent1.parent
  config = PurePath(expfolder,Path("config.ini"))
  assert os.path.exists(config),'The configuration file for the experiment could not be located. Abort.'
+
+ print(f"Position: {extract_position_name(pos)}...")
  print("Configuration file: ",config)
+ print(f"Population: {mode}...")

  ####################################
  # Check model requirements #########
  ####################################

  modelpath = os.sep.join([os.path.split(os.path.dirname(os.path.realpath(__file__)))[0],"models"])
- print(modelpath)
  model_complete_path = locate_segmentation_model(modelname)
  if model_complete_path is None:
  print('Model could not be found. Abort.')
  os.abort()
  else:
- print(f'Model successfully located in {model_complete_path}')
+ print(f'Model path: {model_complete_path}...')

  # load config
  assert os.path.exists(model_complete_path+"config_input.json"),'The configuration for the inputs to the model could not be located. Abort.'
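
Note: the new extract_position_name helper is only used for logging here; its implementation lives in celldetective/io.py, whose changes are listed above but not shown in this diff. A minimal sketch of what such a helper could look like, assuming it simply returns the name of the position folder from its path:

# Hypothetical sketch -- not the shipped implementation from celldetective/io.py.
import os

def extract_position_name(pos_path):
    # Drop any trailing separator, then keep the last path component,
    # e.g. "/experiment/W1/100/" -> "100".
    return os.path.basename(os.path.normpath(pos_path))

print(f"Position: {extract_position_name('/experiment/W1/100/')}...")  # -> Position: 100...
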
@@ -86,11 +85,9 @@ required_channels = input_config["channels"]
  channel_indices = _extract_channel_indices_from_config(config, required_channels)
  print(f'Required channels: {required_channels} located at channel indices {channel_indices}.')
  required_spatial_calibration = input_config['spatial_calibration']
- print(f'Expected spatial calibration is {required_spatial_calibration}.')
+ print(f'Spatial calibration expected by the model: {required_spatial_calibration}...')

- normalization_percentile = input_config['normalization_percentile']
- normalization_clip = input_config['normalization_clip']
- normalization_values = input_config['normalization_values']
+ normalize_kwargs = _get_normalize_kwargs_from_config(input_config)

  model_type = input_config['model_type']

@@ -117,18 +114,19 @@ if model_type=='cellpose':
  flow_threshold = input_config['flow_threshold']

  scale = _estimate_scale_factor(spatial_calibration, required_spatial_calibration)
- print(f"Scale = {scale}...")
+ print(f"Scale: {scale}...")

  nbr_channels = _extract_nbr_channels_from_config(config)
- print(f'Number of channels in the input movie: {nbr_channels}')
+ #print(f'Number of channels in the input movie: {nbr_channels}')
  img_num_channels = _get_img_num_per_channel(channel_indices, int(len_movie), nbr_channels)

  # If everything OK, prepare output, load models
- print('Erasing previous segmentation folder.')
  if os.path.exists(pos+label_folder):
+ print('Erasing the previous labels folder...')
  rmtree(pos+label_folder)
  os.mkdir(pos+label_folder)
- print(f'Folder {pos+label_folder} successfully generated.')
+ print(f'Labels folder successfully generated...')
+
  log=f'segmentation model: {modelname}\n'
  with open(pos+f'log_{mode}.json', 'a') as f:
  f.write(f'{datetime.datetime.now()} SEGMENT \n')
@@ -137,93 +135,52 @@ with open(pos+f'log_{mode}.json', 'a') as f:

  # Loop over all frames and segment
  def segment_index(indices):
- global scale

  if model_type=='stardist':
- model = StarDist2D(None, name=modelname, basedir=Path(model_complete_path).parent)
- model.config.use_gpu = use_gpu
- model.use_gpu = use_gpu
- print(f"StarDist model {modelname} successfully loaded.")
- scale_model = scale
+ model, scale_model = _prep_stardist_model(modelname, Path(model_complete_path).parent, use_gpu=use_gpu, scale=scale)

  elif model_type=='cellpose':
-
- import torch
- if not use_gpu:
- device = torch.device("cpu")
- else:
- device = torch.device("cuda")
-
- model = CellposeModel(gpu=use_gpu, device=device, pretrained_model=model_complete_path+modelname, model_type=None, nchan=len(required_channels)) #diam_mean=30.0,
- if scale is None:
- scale_model = model.diam_mean / model.diam_labels
- else:
- scale_model = scale * model.diam_mean / model.diam_labels
- print(f"Diam mean: {model.diam_mean}; Diam labels: {model.diam_labels}; Final rescaling: {scale_model}...")
- print(f'Cellpose model {modelname} successfully loaded.')
+ model, scale_model = _prep_cellpose_model(modelname, model_complete_path, use_gpu=use_gpu, n_channels=len(required_channels), scale=scale)

  for t in tqdm(indices,desc="frame"):
-
- # Load channels at time t
- values = []
- percentiles = []
- for k in range(len(normalization_percentile)):
- if normalization_percentile[k]:
- percentiles.append(normalization_values[k])
- values.append(None)
- else:
- percentiles.append(None)
- values.append(normalization_values[k])
-
- f = load_frames(img_num_channels[:,t], file, scale=scale_model, normalize_input=True, normalize_kwargs={"percentiles": percentiles, 'values': values, 'clip': normalization_clip})
- f = np.moveaxis([interpolate_nan(f[:,:,c].copy()) for c in range(f.shape[-1])],0,-1)
-
- if np.any(img_num_channels[:,t]==-1):
- f[:,:,np.where(img_num_channels[:,t]==-1)[0]] = 0.
-

- if model_type=="stardist":
- Y_pred, details = model.predict_instances(f, n_tiles=model._guess_n_tiles(f), show_tile_progress=False, verbose=False)
- Y_pred = Y_pred.astype(np.uint16)
+ f = _load_frames_to_segment(file, img_num_channels[:,t], scale_model=scale_model, normalize_kwargs=normalize_kwargs)

+ if model_type=="stardist":
+ Y_pred = _segment_image_with_stardist_model(f, model=model, return_details=False)
  elif model_type=="cellpose":
-
- img = np.moveaxis(f, -1, 0)
- Y_pred, _, _ = model.eval(img, diameter = diameter, cellprob_threshold=cellprob_threshold, flow_threshold=flow_threshold, channels=None, normalize=False)
- Y_pred = Y_pred.astype(np.uint16)
+ Y_pred = _segment_image_with_cellpose_model(f, model=model, diameter=diameter, cellprob_threshold=cellprob_threshold, flow_threshold=flow_threshold)

  if scale is not None:
- Y_pred = zoom(Y_pred, [1./scale_model,1./scale_model],order=0)
+ Y_pred = _rescale_labels(Y_pred, scale_model=scale_model)

- template = load_frames(0,file,scale=1,normalize_input=False)
- if Y_pred.shape != template.shape[:2]:
- Y_pred = resize(Y_pred, template.shape[:2], order=0)
+ Y_pred = _check_label_dims(Y_pred, file)

  save_tiff_imagej_compatible(pos+os.sep.join([label_folder,f"{str(t).zfill(4)}.tif"]), Y_pred, axes='YX')

  del f;
- del template;
  del Y_pred;
  gc.collect()

+ del model;
+ gc.collect()

- import concurrent.futures
+ return
+
+
+ print(f"Starting the segmentation with {n_threads} thread(s) and GPU={use_gpu}...")

  # Multithreading
  indices = list(range(img_num_channels.shape[1]))
  chunks = np.array_split(indices, n_threads)

  with concurrent.futures.ThreadPoolExecutor() as executor:
- executor.map(segment_index, chunks)
-
- # threads = []
- # for i in range(n_threads):
- # thread_i = threading.Thread(target=segment_index, args=[chunks[i]])
- # threads.append(thread_i)
- # for th in threads:
- # th.start()
- # for th in threads:
- # th.join()
+ results = executor.map(segment_index, chunks)
+ try:
+ for i,return_value in enumerate(results):
+ print(f"Thread {i} output check: ",return_value)
+ except Exception as e:
+ print("Exception: ", e)

  print('Done.')
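
Note: the behavioural change at the end of this script is that the executor.map results are now iterated. ThreadPoolExecutor.map returns a lazy iterator and exceptions raised inside a worker only surface when the corresponding result is consumed, so iterating the results makes worker failures visible instead of being silently swallowed. A condensed, self-contained illustration of the pattern (the worker name and the chunking mirror the script; everything else is simplified):

# Minimal illustration of the map-then-consume pattern used above.
import concurrent.futures
import numpy as np

def segment_index(indices):
    # Stand-in worker: the real one segments every frame index in the chunk.
    for t in indices:
        if t < 0:
            raise ValueError(f"bad frame index {t}")
    return None  # the script returns None and only checks that the thread finished

n_threads = 2
indices = list(range(10))
chunks = np.array_split(indices, n_threads)

with concurrent.futures.ThreadPoolExecutor() as executor:
    results = executor.map(segment_index, chunks)
    try:
        for i, return_value in enumerate(results):
            # Raises here if the corresponding worker raised.
            print(f"Thread {i} output check: ", return_value)
    except Exception as e:
        print("Exception: ", e)
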
celldetective/scripts/segment_cells_thresholds.py

@@ -5,7 +5,7 @@ Copright © 2022 Laboratoire Adhesion et Inflammation, Authored by Remy Torro.
  import argparse
  import os
  import json
- from celldetective.io import auto_load_number_of_frames, load_frames
+ from celldetective.io import auto_load_number_of_frames, load_frames, extract_position_name
  from celldetective.segmentation import segment_frame_from_thresholds
  from celldetective.utils import _extract_channel_indices_from_config, ConfigSectionMap, _extract_nbr_channels_from_config, _get_img_num_per_channel, extract_experiment_channels
  from pathlib import Path, PurePath
@@ -16,6 +16,7 @@ import numpy as np
  from csbdeep.io import save_tiff_imagej_compatible
  import gc
  from art import tprint
+ import concurrent.futures

  tprint("Segment")

@@ -48,8 +49,6 @@ else:
  print('The configuration path is not valid. Abort.')
  os.abort()

- print('The following instructions were successfully loaded: ', threshold_instructions)
-
  if mode.lower()=="target" or mode.lower()=="targets":
  label_folder = "labels_targets"
  elif mode.lower()=="effector" or mode.lower()=="effectors":
@@ -60,12 +59,14 @@ parent1 = Path(pos).parent
  expfolder = parent1.parent
  config = PurePath(expfolder,Path("config.ini"))
  assert os.path.exists(config),'The configuration file for the experiment could not be located. Abort.'
- print("Configuration file: ",config)

+ print(f"Position: {extract_position_name(pos)}...")
+ print("Configuration file: ",config)
+ print(f"Population: {mode}...")

  channel_indices = _extract_channel_indices_from_config(config, required_channels)
  # need to abort if channel not found
- print(f'Required channels: {required_channels} located at channel indices {channel_indices}.')
+ print(f'Required channels: {required_channels} located at channel indices {channel_indices}...')

  threshold_instructions.update({'target_channel': channel_indices[0]})

@@ -86,15 +87,15 @@ if len_movie_auto is not None:
  len_movie = len_movie_auto

  nbr_channels = _extract_nbr_channels_from_config(config)
- print(f'Number of channels in the input movie: {nbr_channels}')
+ #print(f'Number of channels in the input movie: {nbr_channels}')
  img_num_channels = _get_img_num_per_channel(np.arange(nbr_channels), len_movie, nbr_channels)

  # If everything OK, prepare output, load models
- print('Erasing previous segmentation folder.')
  if os.path.exists(os.sep.join([pos,label_folder])):
+ print('Erasing the previous labels folder...')
  rmtree(os.sep.join([pos,label_folder]))
  os.mkdir(os.sep.join([pos,label_folder]))
- print(f'Folder {os.sep.join([pos,label_folder])} successfully generated.')
+ print(f'Labels folder successfully generated...')

  if equalize:
  f_reference = load_frames(img_num_channels[:,equalize_time], file, scale=None, normalize_input=False)
@@ -103,7 +104,7 @@ else:
  f_reference = None

  threshold_instructions.update({'equalize_reference': f_reference})
- print(threshold_instructions)
+ print(f"Instructions: {threshold_instructions}...")

  # Loop over all frames and segment
  def segment_index(indices):
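
Note: the equalize_reference entry added to threshold_instructions above is a reference frame loaded at equalize_time; inside segment_frame_from_thresholds (see the celldetective/segmentation.py hunks further below) each frame is histogram-matched to it before thresholding. A minimal, standalone sketch of that equalization step with skimage, using made-up arrays rather than the package's loaders:

# Histogram-matching a frame to a reference frame, as the equalize_reference
# option does (standalone sketch, not the package code).
import numpy as np
from skimage.exposure import match_histograms

rng = np.random.default_rng(0)
reference = rng.normal(100, 10, size=(256, 256))   # frame at equalize_time
frame = rng.normal(150, 30, size=(256, 256))       # frame to segment

equalized = match_histograms(frame, reference)
print(equalized.mean(), reference.mean())  # intensities now follow the reference distribution
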
@@ -119,27 +120,25 @@ def segment_index(indices):
  del mask;
  gc.collect()

- import concurrent.futures
+ return
+
+
+ print(f"Starting the segmentation with {n_threads} thread(s)...")

  # Multithreading
  indices = list(range(img_num_channels.shape[1]))
  chunks = np.array_split(indices, n_threads)

  with concurrent.futures.ThreadPoolExecutor() as executor:
- executor.map(segment_index, chunks)
-
- # indices = list(range(img_num_channels.shape[1]))
- # chunks = np.array_split(indices, n_threads)
- # threads = []
- # for i in range(n_threads):
- # thread_i = threading.Thread(target=segment_index, args=[chunks[i]])
- # threads.append(thread_i)
- # for th in threads:
- # th.start()
- # for th in threads:
- # th.join()
+ results = executor.map(segment_index, chunks)
+ try:
+ for i,return_value in enumerate(results):
+ print(f"Thread {i} output check: ",return_value)
+ except Exception as e:
+ print("Exception: ", e)

  print('Done.')
+
  gc.collect()

celldetective/scripts/track_cells.py

@@ -6,8 +6,8 @@ import argparse
  import datetime
  import os
  import json
- from celldetective.io import auto_load_number_of_frames, load_frames, interpret_tracking_configuration
- from celldetective.utils import extract_experiment_channels, ConfigSectionMap, _get_img_num_per_channel, extract_experiment_channels
+ from celldetective.io import auto_load_number_of_frames, interpret_tracking_configuration, extract_position_name
+ from celldetective.utils import _mask_intensity_measurements, extract_experiment_channels, ConfigSectionMap, _get_img_num_per_channel, extract_experiment_channels
  from celldetective.measure import drop_tonal_features, measure_features
  from celldetective.tracking import track
  from pathlib import Path, PurePath
@@ -19,7 +19,8 @@ import gc
  import os
  from natsort import natsorted
  from art import tprint
- from tifffile import imread
+ import concurrent.futures
+

  tprint("Track")

@@ -59,6 +60,10 @@ expfolder = parent1.parent
  config = PurePath(expfolder,Path("config.ini"))
  assert os.path.exists(config),'The configuration file for the experiment could not be located. Abort.'

+ print(f"Position: {extract_position_name(pos)}...")
+ print("Configuration file: ",config)
+ print(f"Population: {mode}...")
+
  # from exp config fetch spatial calib, channel names
  movie_prefix = ConfigSectionMap(config,"MovieSettings")["movie_prefix"]
  spatial_calibration = float(ConfigSectionMap(config,"MovieSettings")["pxtoum"])
@@ -71,9 +76,10 @@ channel_names, channel_indices = extract_experiment_channels(config)
  nbr_channels = len(channel_names)

  # from tracking instructions, fetch btrack config, features, haralick, clean_traj, idea: fetch custom timeline?
+ print('Looking for tracking instruction file...')
  instr_path = PurePath(expfolder,Path(f"{instruction_file}"))
  if os.path.exists(instr_path):
- print(f"Tracking instructions for the {mode} population have been successfully loaded...")
+ print(f"Tracking instruction file successfully loaded...")
  with open(instr_path, 'r') as f:
  instructions = json.load(f)
  btrack_config = interpret_tracking_configuration(instructions['btrack_config_path'])
@@ -107,8 +113,6 @@ if os.path.exists(instr_path):
  memory = None
  if 'memory' in instructions:
  memory = instructions['memory']
-
-
  else:
  print('Tracking instructions could not be located... Using a standard bTrack motion model instead...')
  btrack_config = interpret_tracking_configuration(None)
@@ -169,51 +173,58 @@ if not btrack_option:


  def measure_index(indices):
+
+ props = []
+
  for t in tqdm(indices,desc="frame"):

  # Load channels at time t
- img = load_frames(img_num_channels[:,t], file, scale=None, normalize_input=False)
- lbl = imread(label_path[t])
+ img = _load_frames_to_measure(file, indices=img_num_channels[:,t])
+ lbl = locate_labels(pos, population=mode, frames=t)
+ if lbl is None:
+ continue

  df_props = measure_features(img, lbl, features = features+['centroid'], border_dist=None,
  channels=channel_names, haralick_options=haralick_options, verbose=False,
  )
  df_props.rename(columns={'centroid-1': 'x', 'centroid-0': 'y'},inplace=True)
  df_props['t'] = int(t)
- timestep_dataframes.append(df_props)
+
+ props.append(df_props)
+
+ return props
+
+ print(f"Measuring features with {n_threads} thread(s)...")

  # Multithreading
  indices = list(range(img_num_channels.shape[1]))
  chunks = np.array_split(indices, n_threads)

- import concurrent.futures
-
+ timestep_dataframes = []
  with concurrent.futures.ThreadPoolExecutor() as executor:
- executor.map(measure_index, chunks)
+ results = executor.map(measure_index, chunks)
+ try:
+ for i,return_value in enumerate(results):
+ print(f"Thread {i} completed...")
+ timestep_dataframes.extend(return_value)
+ except Exception as e:
+ print("Exception: ", e)

- # threads = []
- # for i in range(n_threads):
- # thread_i = threading.Thread(target=measure_index, args=[chunks[i]])
- # threads.append(thread_i)
- # for th in threads:
- # th.start()
- # for th in threads:
- # th.join()
+ print('Features successfully measured...')

  df = pd.concat(timestep_dataframes)
  df.reset_index(inplace=True, drop=True)

- if mask_channels is not None:
- cols_to_drop = []
- for mc in mask_channels:
- columns = df.columns
- col_contains = [mc in c for c in columns]
- to_remove = np.array(columns)[np.array(col_contains)]
- cols_to_drop.extend(to_remove)
- if len(cols_to_drop)>0:
- df = df.drop(cols_to_drop, axis=1)
+ df = _mask_intensity_measurements(df, mask_channels)

  # do tracking
+ if btrack_option:
+ tracker = 'bTrack'
+ else:
+ tracker = 'trackpy'
+
+ print(f"Start the tracking step using the {tracker} tracker...")
+
  trajectories, napari_data = track(None,
  configuration=btrack_config,
  objects=df,
@@ -228,14 +239,14 @@ trajectories, napari_data = track(None,
  search_range=search_range,
  memory=memory,
  )
+ print(f"Tracking successfully performed...")

  # out trajectory table, create POSITION_X_um, POSITION_Y_um, TIME_min (new ones)
- # Save napari data
+ # Save napari data # deprecated, should disappear progressively
  np.save(pos+os.sep.join(['output', 'tables', napari_name]), napari_data, allow_pickle=True)
- print(f"napari data successfully saved in {pos+os.sep.join(['output', 'tables'])}")

  trajectories.to_csv(pos+os.sep.join(['output', 'tables', table_name]), index=False)
- print(f"Table {table_name} successfully saved in {os.sep.join(['output', 'tables'])}")
+ print(f"Trajectory table successfully exported in {os.sep.join(['output', 'tables'])}...")

  if os.path.exists(pos+os.sep.join(['output', 'tables', table_name.replace('.csv','.pkl')])):
  os.remove(pos+os.sep.join(['output', 'tables', table_name.replace('.csv','.pkl')]))
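
Note: in the hunks above, _mask_intensity_measurements absorbs the inline block that dropped every measurement column mentioning a masked channel. Based on the removed lines, an equivalent helper would look roughly like the sketch below; the packaged function in celldetective/utils.py is not shown in this diff and may differ in details.

# Sketch reproducing the removed inline logic; not the packaged implementation.
import numpy as np
import pandas as pd

def _mask_intensity_measurements(df, mask_channels):
    if mask_channels is None:
        return df
    cols_to_drop = []
    for mc in mask_channels:
        columns = df.columns
        # Drop every column whose name contains the masked channel name.
        col_contains = [mc in c for c in columns]
        to_remove = np.array(columns)[np.array(col_contains)]
        cols_to_drop.extend(to_remove)
    if len(cols_to_drop) > 0:
        df = df.drop(cols_to_drop, axis=1)
    return df

df = pd.DataFrame({"area": [10], "intensity_mean_RICM": [0.2], "x": [5]})
print(_mask_intensity_measurements(df, ["RICM"]).columns.tolist())  # ['area', 'x']
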
celldetective/segmentation.py

@@ -8,12 +8,9 @@ from .utils import _estimate_scale_factor, _extract_channel_indices
  from pathlib import Path
  from tqdm import tqdm
  import numpy as np
- from stardist.models import StarDist2D
- from cellpose.models import CellposeModel
- from skimage.transform import resize
- from celldetective.io import _view_on_napari, locate_labels, locate_stack, _view_on_napari
+ from celldetective.io import _view_on_napari, locate_labels, locate_stack, _view_on_napari, _check_label_dims
  from celldetective.filters import * #rework this to give a name
- from celldetective.utils import rename_intensity_column, mask_edges, estimate_unreliable_edge
+ from celldetective.utils import interpolate_nan_multichannel,_rearrange_multichannel_frame, _fix_no_contrast, zoom_multiframes, _rescale_labels, rename_intensity_column, mask_edges, _prep_stardist_model, _prep_cellpose_model, estimate_unreliable_edge,_get_normalize_kwargs_from_config, _segment_image_with_stardist_model, _segment_image_with_cellpose_model
  from stardist import fill_label_holes
  import scipy.ndimage as ndi
  from skimage.segmentation import watershed
@@ -99,9 +96,7 @@ def segment(stack, model_name, channels=None, spatial_calibration=None, view_on_
  required_spatial_calibration = input_config['spatial_calibration']
  model_type = input_config['model_type']

- normalization_percentile = input_config['normalization_percentile']
- normalization_clip = input_config['normalization_clip']
- normalization_values = input_config['normalization_values']
+ normalize_kwargs = _get_normalize_kwargs_from_config(input_config)

  if model_type=='cellpose':
  diameter = input_config['diameter']
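
Note: _get_normalize_kwargs_from_config replaces the three separate normalization_* reads above and the per-channel loop that previously built the normalize kwargs inline (the removed loop is still visible further down in the old segment() body, and in the old segment_index() of segment_cells.py). A sketch of the equivalent packaging, assuming the helper keeps the same percentile/value convention; the shipped function in celldetective/utils.py may differ:

# Sketch of the packaging previously done inline; illustrative only.
def _get_normalize_kwargs_from_config(input_config):
    normalization_percentile = input_config['normalization_percentile']
    normalization_clip = input_config['normalization_clip']
    normalization_values = input_config['normalization_values']

    values, percentiles = [], []
    for k in range(len(normalization_percentile)):
        if normalization_percentile[k]:
            # Channel k is normalized from percentiles.
            percentiles.append(normalization_values[k])
            values.append(None)
        else:
            # Channel k is normalized from fixed min/max values.
            percentiles.append(None)
            values.append(normalization_values[k])

    return {"percentiles": percentiles, "values": values, "clip": normalization_clip}

# Hypothetical two-channel config, shaped like the inline code expects:
cfg = {'normalization_percentile': [True, False],
       'normalization_clip': [False, False],
       'normalization_values': [[1., 99.], [0., 1000.]]}
print(_get_normalize_kwargs_from_config(cfg))
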
@@ -116,28 +111,10 @@ def segment(stack, model_name, channels=None, spatial_calibration=None, view_on_
  print(f"{spatial_calibration=} {required_spatial_calibration=} Scale = {scale}...")

  if model_type=='stardist':
-
- model = StarDist2D(None, name=model_name, basedir=Path(model_path).parent)
- model.config.use_gpu = use_gpu
- model.use_gpu = use_gpu
- print(f"StarDist model {model_name} successfully loaded.")
- scale_model = scale
+ model, scale_model = _prep_stardist_model(model_name, Path(model_path).parent, use_gpu=use_gpu, scale=scale)

  elif model_type=='cellpose':
-
- import torch
- if not use_gpu:
- device = torch.device("cpu")
- else:
- device = torch.device("cuda")
-
- model = CellposeModel(gpu=use_gpu, device=device, pretrained_model=model_path+model_path.split('/')[-2], model_type=None, nchan=len(required_channels))
- if scale is None:
- scale_model = model.diam_mean / model.diam_labels
- else:
- scale_model = scale * model.diam_mean / model.diam_labels
- print(f"Diam mean: {model.diam_mean}; Diam labels: {model.diam_labels}; Final rescaling: {scale_model}...")
- print(f'Cellpose model {model_name} successfully loaded.')
+ model, scale_model = _prep_cellpose_model(model_path.split('/')[-2], model_path, use_gpu=use_gpu, n_channels=len(required_channels), scale=scale)

  labels = []
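
Note: both the scripts and segmentation.py now share _prep_stardist_model / _prep_cellpose_model, which factor out the model construction and rescaling logic that the removed lines used to repeat. Based on those removed lines, the Cellpose variant plausibly looks like the sketch below; the shipped helper in celldetective/utils.py is not shown here and may differ.

# Sketch assembled from the removed inline code; argument names follow the new call sites.
import torch
from cellpose.models import CellposeModel

def _prep_cellpose_model(model_name, model_path, use_gpu=False, n_channels=1, scale=None):
    device = torch.device("cuda") if use_gpu else torch.device("cpu")
    model = CellposeModel(gpu=use_gpu, device=device,
                          pretrained_model=model_path + model_name,
                          model_type=None, nchan=n_channels)
    # Rescale so that objects match the diameter the model was trained on.
    if scale is None:
        scale_model = model.diam_mean / model.diam_labels
    else:
        scale_model = scale * model.diam_mean / model.diam_labels
    print(f"Diam mean: {model.diam_mean}; Diam labels: {model.diam_labels}; Final rescaling: {scale_model}...")
    return model, scale_model
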
@@ -149,11 +126,7 @@ def segment(stack, model_name, channels=None, spatial_calibration=None, view_on_
  channel_indices[channel_indices==None] = 0

  frame = stack[t]
- #frame = stack[t,:,:,channel_indices.astype(int)].astype(float)
- if frame.ndim==2:
- frame = frame[:,:,np.newaxis]
- if frame.ndim==3 and np.array(frame.shape).argmin()==0:
- frame = np.moveaxis(frame,0,-1)
+ frame = _rearrange_multichannel_frame(frame).astype(float)

  frame_to_segment = np.zeros((frame.shape[0], frame.shape[1], len(required_channels))).astype(float)
  for ch in channel_intersection:
@@ -162,45 +135,25 @@ def segment(stack, model_name, channels=None, spatial_calibration=None, view_on_
  frame = frame_to_segment
  template = frame.copy()

- values = []
- percentiles = []
- for k in range(len(normalization_percentile)):
- if normalization_percentile[k]:
- percentiles.append(normalization_values[k])
- values.append(None)
- else:
- percentiles.append(None)
- values.append(normalization_values[k])
-
- frame = normalize_multichannel(frame, **{"percentiles": percentiles, 'values': values, 'clip': normalization_clip})
+ frame = normalize_multichannel(frame, **normalize_kwargs)

  if scale_model is not None:
- frame = [zoom(frame[:,:,c].copy(), [scale_model,scale_model], order=3, prefilter=False) for c in range(frame.shape[-1])]
- frame = np.moveaxis(frame,0,-1)
-
- for k in range(frame.shape[2]):
- unique_values = np.unique(frame[:,:,k])
- if len(unique_values)==1:
- frame[0,0,k] += 1
+ frame = zoom_multiframes(frame, scale_model)

- frame = np.moveaxis([interpolate_nan(frame[:,:,c].copy()) for c in range(frame.shape[-1])],0,-1)
+ frame = _fix_no_contrast(frame)
+ frame = interpolate_nan_multichannel(frame)
  frame[:,:,none_channel_indices] = 0.

  if model_type=="stardist":
-
- Y_pred, details = model.predict_instances(frame, n_tiles=model._guess_n_tiles(frame), show_tile_progress=False, verbose=False)
- Y_pred = Y_pred.astype(np.uint16)
+ Y_pred = _segment_image_with_stardist_model(frame, model=model, return_details=False)

  elif model_type=="cellpose":
-
- img = np.moveaxis(frame, -1, 0)
- Y_pred, _, _ = model.eval(img, diameter = diameter, cellprob_threshold=cellprob_threshold, flow_threshold=flow_threshold, channels=None, normalize=False)
- Y_pred = Y_pred.astype(np.uint16)
+ Y_pred = _segment_image_with_cellpose_model(frame, model=model, diameter=diameter, cellprob_threshold=cellprob_threshold, flow_threshold=flow_threshold)

  if Y_pred.shape != stack[0].shape[:2]:
- Y_pred = zoom(Y_pred, [1./scale_model,1./scale_model],order=0)
- if Y_pred.shape != template.shape[:2]:
- Y_pred = resize(Y_pred, template.shape[:2], order=0)
+ Y_pred = _rescale_labels(Y_pred, scale_model)
+
+ Y_pred = _check_label_dims(Y_pred, template=template)

  labels.append(Y_pred)
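
Note: the new frame-preparation helpers (zoom_multiframes, _fix_no_contrast, interpolate_nan_multichannel, _rescale_labels, _check_label_dims) each wrap a few lines that were previously inlined; interpolate_nan_multichannel presumably applies the existing interpolate_nan channel by channel, and the script variant of _check_label_dims takes the stack file instead of a template frame. Going by the removed code, plausible one-to-one equivalents are sketched below; the shipped implementations in celldetective/utils.py and celldetective/io.py may differ.

# Sketches reconstructed from the removed inline code; not the packaged implementations.
import numpy as np
from scipy.ndimage import zoom
from skimage.transform import resize

def zoom_multiframes(frame, scale_model):
    # Rescale each channel of an (H, W, C) frame by scale_model (cubic interpolation).
    channels = [zoom(frame[:, :, c].copy(), [scale_model, scale_model], order=3, prefilter=False)
                for c in range(frame.shape[-1])]
    return np.moveaxis(channels, 0, -1)

def _fix_no_contrast(frame):
    # Nudge one pixel in any flat channel so normalization does not divide by zero.
    for k in range(frame.shape[2]):
        if len(np.unique(frame[:, :, k])) == 1:
            frame[0, 0, k] += 1
    return frame

def _rescale_labels(labels, scale_model):
    # Undo the model-space rescaling on the label map (nearest neighbour keeps label ids intact).
    return zoom(labels, [1. / scale_model, 1. / scale_model], order=0)

def _check_label_dims(labels, template=None):
    # Force the label map back to the template image shape if rounding left them off by a pixel.
    if template is not None and labels.shape != template.shape[:2]:
        labels = resize(labels, template.shape[:2], order=0)
    return labels
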
@@ -315,6 +268,7 @@ def segment_frame_from_thresholds(frame, target_channel=0, thresholds=None, equa
  """

  img = frame[:,:,target_channel]
+ img = interpolate_nan(img)
  if equalize_reference is not None:
  img = match_histograms(img, equalize_reference)
  img_mc = frame.copy()