celldetective 1.1.1.post3-py3-none-any.whl → 1.2.0-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (42)
  1. celldetective/__init__.py +2 -1
  2. celldetective/__main__.py +17 -0
  3. celldetective/extra_properties.py +62 -34
  4. celldetective/gui/__init__.py +1 -0
  5. celldetective/gui/analyze_block.py +2 -1
  6. celldetective/gui/classifier_widget.py +18 -10
  7. celldetective/gui/control_panel.py +57 -6
  8. celldetective/gui/layouts.py +14 -11
  9. celldetective/gui/neighborhood_options.py +21 -13
  10. celldetective/gui/plot_signals_ui.py +39 -11
  11. celldetective/gui/process_block.py +413 -95
  12. celldetective/gui/retrain_segmentation_model_options.py +17 -4
  13. celldetective/gui/retrain_signal_model_options.py +106 -6
  14. celldetective/gui/signal_annotator.py +110 -30
  15. celldetective/gui/signal_annotator2.py +2708 -0
  16. celldetective/gui/signal_annotator_options.py +3 -1
  17. celldetective/gui/survival_ui.py +15 -6
  18. celldetective/gui/tableUI.py +248 -43
  19. celldetective/io.py +598 -416
  20. celldetective/measure.py +919 -969
  21. celldetective/models/pair_signal_detection/blank +0 -0
  22. celldetective/neighborhood.py +482 -340
  23. celldetective/preprocessing.py +81 -61
  24. celldetective/relative_measurements.py +648 -0
  25. celldetective/scripts/analyze_signals.py +1 -1
  26. celldetective/scripts/measure_cells.py +28 -8
  27. celldetective/scripts/measure_relative.py +103 -0
  28. celldetective/scripts/segment_cells.py +5 -5
  29. celldetective/scripts/track_cells.py +4 -1
  30. celldetective/scripts/train_segmentation_model.py +23 -18
  31. celldetective/scripts/train_signal_model.py +33 -0
  32. celldetective/segmentation.py +67 -29
  33. celldetective/signals.py +402 -8
  34. celldetective/tracking.py +8 -2
  35. celldetective/utils.py +144 -12
  36. {celldetective-1.1.1.post3.dist-info → celldetective-1.2.0.dist-info}/METADATA +8 -8
  37. {celldetective-1.1.1.post3.dist-info → celldetective-1.2.0.dist-info}/RECORD +42 -38
  38. {celldetective-1.1.1.post3.dist-info → celldetective-1.2.0.dist-info}/WHEEL +1 -1
  39. tests/test_segmentation.py +1 -1
  40. {celldetective-1.1.1.post3.dist-info → celldetective-1.2.0.dist-info}/LICENSE +0 -0
  41. {celldetective-1.1.1.post3.dist-info → celldetective-1.2.0.dist-info}/entry_points.txt +0 -0
  42. {celldetective-1.1.1.post3.dist-info → celldetective-1.2.0.dist-info}/top_level.txt +0 -0
celldetective/scripts/measure_cells.py

@@ -72,6 +72,8 @@ if os.path.exists(instr_path):
     print("Reading the following instructions: ", instructions)
     if 'background_correction' in instructions:
         background_correction = instructions['background_correction']
+    else:
+        background_correction = None
 
     if 'features' in instructions:
         features = instructions['features']
@@ -145,14 +147,19 @@ except IndexError:
 # Load trajectories, add centroid if not in trajectory
 trajectories = pos+os.sep.join(['output','tables', table_name])
 if os.path.exists(trajectories):
+    print('trajectory exists...')
     trajectories = pd.read_csv(trajectories)
     if 'TRACK_ID' not in list(trajectories.columns):
         do_iso_intensities = False
         intensity_measurement_radii = None
         if clear_previous:
-            trajectories = remove_trajectory_measurements(trajectories, column_labels)
+            print('No TRACK_ID... Clear previous measurements...')
+            trajectories = None #remove_trajectory_measurements(trajectories, column_labels)
+            do_features = True
+            features += ['centroid']
     else:
         if clear_previous:
+            print('TRACK_ID found... Clear previous measurements...')
             trajectories = remove_trajectory_measurements(trajectories, column_labels)
         else:
             trajectories = None
@@ -161,11 +168,11 @@ else:
     do_iso_intensities = False
 
 
-if (features is not None) and (trajectories is not None):
-    features = remove_redundant_features(features,
-                                         trajectories.columns,
-                                         channel_names=channel_names
-                                         )
+# if (features is not None) and (trajectories is not None):
+#     features = remove_redundant_features(features,
+#                                          trajectories.columns,
+#                                          channel_names=channel_names
+#                                          )
 
 len_movie_auto = auto_load_number_of_frames(file)
 if len_movie_auto is not None:
@@ -187,6 +194,7 @@ if label_path is None:
 else:
     do_features = True
 
+
 #######################################
 # Loop over all frames and find objects
 #######################################
@@ -196,6 +204,9 @@ if trajectories is None:
     print('Use features as a substitute for the trajectory table.')
     if 'label' not in features:
         features.append('label')
+
+
+
 features_log=f'features: {features}'
 border_distances_log=f'border_distances: {border_distances}'
 haralick_options_log=f'haralick_options: {haralick_options}'
@@ -237,7 +248,7 @@ def measure_index(indices):
             column_labels = {'track': "ID", 'time': column_labels['time'], 'x': column_labels['x'],
                              'y': column_labels['y']}
             feature_table.rename(columns={'centroid-1': 'POSITION_X', 'centroid-0': 'POSITION_Y'}, inplace=True)
-
+
         if do_iso_intensities:
             iso_table = measure_isotropic_intensity(positions_at_t, img, channels=channel_names, intensity_measurement_radii=intensity_measurement_radii, column_labels=column_labels, operations=isotropic_operations, verbose=False)
 
@@ -249,11 +260,20 @@ def measure_index(indices):
         elif do_features:
             measurements_at_t = positions_at_t.merge(feature_table, how='outer', on='class_id',suffixes=('', '_delme'))
             measurements_at_t = measurements_at_t[[c for c in measurements_at_t.columns if not c.endswith('_delme')]]
-
+
+            center_of_mass_x_cols = [c for c in list(measurements_at_t.columns) if c.endswith('centre_of_mass_x')]
+            center_of_mass_y_cols = [c for c in list(measurements_at_t.columns) if c.endswith('centre_of_mass_y')]
+            for c in center_of_mass_x_cols:
+                measurements_at_t.loc[:,c.replace('_x','_POSITION_X')] = measurements_at_t[c] + measurements_at_t['POSITION_X']
+            for c in center_of_mass_y_cols:
+                measurements_at_t.loc[:,c.replace('_y','_POSITION_Y')] = measurements_at_t[c] + measurements_at_t['POSITION_Y']
+            measurements_at_t = measurements_at_t.drop(columns = center_of_mass_x_cols+center_of_mass_y_cols)
+
         if measurements_at_t is not None:
             measurements_at_t[column_labels['time']] = t
             timestep_dataframes.append(measurements_at_t)
 
+
 # Multithreading
 indices = list(range(img_num_channels.shape[1]))
 chunks = np.array_split(indices, n_threads)
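
Note: the added block above converts regionprops centre-of-mass columns from offsets into absolute image coordinates by adding the cell centroid, then drops the raw offset columns. A minimal pandas sketch of the same transformation, with invented column names and values:

import pandas as pd

# Toy table: centroid columns plus one centre-of-mass offset pair (made-up names).
df = pd.DataFrame({
    'POSITION_X': [10.0, 20.0],
    'POSITION_Y': [5.0, 8.0],
    'ch1_centre_of_mass_x': [1.5, -0.5],
    'ch1_centre_of_mass_y': [0.2, 0.7],
})

x_cols = [c for c in df.columns if c.endswith('centre_of_mass_x')]
y_cols = [c for c in df.columns if c.endswith('centre_of_mass_y')]
for c in x_cols:
    df.loc[:, c.replace('_x', '_POSITION_X')] = df[c] + df['POSITION_X']  # offset -> absolute
for c in y_cols:
    df.loc[:, c.replace('_y', '_POSITION_Y')] = df[c] + df['POSITION_Y']
df = df.drop(columns=x_cols + y_cols)
print(df)  # ch1_centre_of_mass_POSITION_X: 11.5, 19.5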
celldetective/scripts/measure_relative.py (new file)

@@ -0,0 +1,103 @@
+import argparse
+import os
+import json
+from celldetective.relative_measurements import measure_pair_signals_at_position, update_effector_table, extract_neighborhoods_from_pickles
+from celldetective.utils import ConfigSectionMap, extract_experiment_channels
+
+from pathlib import Path, PurePath
+
+import pandas as pd
+
+from art import tprint
+
+
+tprint("Measure pairs")
+
+parser = argparse.ArgumentParser(description="Measure features and intensities in a multichannel timeseries.",
+                                 formatter_class=argparse.ArgumentDefaultsHelpFormatter)
+parser.add_argument('-p', "--position", required=True, help="Path to the position")
+
+args = parser.parse_args()
+process_arguments = vars(args)
+pos = str(process_arguments['position'])
+
+instruction_file = os.sep.join(['configs', "neighborhood_instructions.json"])
+
+# Locate experiment config
+parent1 = Path(pos).parent
+expfolder = parent1.parent
+config = PurePath(expfolder, Path("config.ini"))
+assert os.path.exists(config), 'The configuration file for the experiment could not be located. Abort.'
+print("Configuration file: ", config)
+
+# from exp config fetch spatial calib, channel names
+movie_prefix = ConfigSectionMap(config, "MovieSettings")["movie_prefix"]
+spatial_calibration = float(ConfigSectionMap(config, "MovieSettings")["pxtoum"])
+time_calibration = float(ConfigSectionMap(config, "MovieSettings")["frametomin"])
+len_movie = float(ConfigSectionMap(config, "MovieSettings")["len_movie"])
+channel_names, channel_inneigh_protocoles = extract_experiment_channels(config)
+nbr_channels = len(channel_names)
+
+# from tracking instructions, fetch btrack config, features, haralick, clean_traj, idea: fetch custom timeline?
+instr_path = PurePath(expfolder, Path(f"{instruction_file}"))
+previous_pair_table_path = pos + os.sep.join(['output', 'tables', 'trajectories_pairs.csv'])
+
+# if os.path.exists(instr_path):
+#     print(f"Neighborhood instructions has been successfully located.")
+#     with open(instr_path, 'r') as f:
+#         instructions = json.load(f)
+#     print("Reading the following instructions: ", instructions)
+
+#     if 'distance' in instructions:
+#         distance = instructions['distance'][0]
+#     else:
+#         distance = None
+# else:
+#     print('No measurement instructions found')
+#     os.abort()
+
+previous_neighborhoods = []
+associated_reference_population = []
+
+# if distance is None:
+#     print('No measurement could be performed. Check your inputs.')
+#     print('Done.')
+#     os.abort()
+#     #distance = 0
+# else:
+neighborhoods_to_measure = extract_neighborhoods_from_pickles(pos)
+all_df_pairs = []
+if os.path.exists(previous_pair_table_path):
+    df_0 = pd.read_csv(previous_pair_table_path)
+    previous_neighborhoods = [c.replace('status_','') for c in list(df_0.columns) if c.startswith('status_neighborhood')]
+    for n in previous_neighborhoods:
+        associated_reference_population.append(df_0.loc[~df_0['status_'+n].isnull(),'reference_population'].values[0])
+    print(f'{previous_neighborhoods=} {associated_reference_population=}')
+    all_df_pairs.append(df_0)
+for k,neigh_protocol in enumerate(neighborhoods_to_measure):
+    if neigh_protocol['description'] not in previous_neighborhoods:
+        df_pairs = measure_pair_signals_at_position(pos, neigh_protocol)
+        print(f'{df_pairs=}')
+        if 'REFERENCE_ID' in list(df_pairs.columns):
+            all_df_pairs.append(df_pairs)
+    elif neigh_protocol['description'] in previous_neighborhoods and neigh_protocol['reference'] != associated_reference_population[previous_neighborhoods.index(neigh_protocol['description'])]:
+        df_pairs = measure_pair_signals_at_position(pos, neigh_protocol)
+        if 'REFERENCE_ID' in list(df_pairs.columns):
+            all_df_pairs.append(df_pairs)
+
+print(f'{len(all_df_pairs)} neighborhood measurements sets were computed...')
+
+if len(all_df_pairs)>1:
+    print('Merging...')
+    df_pairs = all_df_pairs[0]
+    for i in range(1,len(all_df_pairs)):
+        cols = [c1 for c1,c2 in zip(list(df_pairs.columns), list(all_df_pairs[i].columns)) if c1==c2]
+        df_pairs = pd.merge(df_pairs.round(decimals=6), all_df_pairs[i].round(decimals=6), how="outer", on=cols)
+elif len(all_df_pairs)==1:
+    df_pairs = all_df_pairs[0]
+
+print('Writing table...')
+df_pairs = df_pairs.sort_values(by=['reference_population', 'neighbor_population', 'REFERENCE_ID', 'NEIGHBOR_ID', 'FRAME'])
+df_pairs.to_csv(previous_pair_table_path, index=False)
+print('Done.')
+
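Note: the new script above merges measurement sets by outer-joining on the columns the tables share, after rounding both sides to 6 decimals so floating-point jitter cannot keep otherwise identical key rows from matching. Also note that the shared columns are found with zip, i.e. compared position-wise rather than as a set intersection. A toy illustration of the pattern:

import pandas as pd

df_a = pd.DataFrame({'REFERENCE_ID': [0, 1], 'FRAME': [0, 0], 'distance': [10.1234567, 3.2]})
df_b = pd.DataFrame({'REFERENCE_ID': [0, 1], 'FRAME': [0, 0], 'status_neighborhood_A': [1, 0]})

# Position-wise comparison: only column names aligned at the same index count as shared.
cols = [c1 for c1, c2 in zip(list(df_a.columns), list(df_b.columns)) if c1 == c2]
merged = pd.merge(df_a.round(decimals=6), df_b.round(decimals=6), how="outer", on=cols)
print(merged)  # one row per (REFERENCE_ID, FRAME), both measurement columns kept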
celldetective/scripts/segment_cells.py

@@ -129,10 +129,10 @@ img_num_channels = _get_img_num_per_channel(channel_indices, int(len_movie), nbr
 
 # If everything OK, prepare output, load models
 print('Erasing previous segmentation folder.')
-if os.path.exists(os.sep.join([pos,label_folder])):
-    rmtree(os.sep.join([pos,label_folder]))
-os.mkdir(os.sep.join([pos,label_folder]))
-print(f'Folder {os.sep.join([pos,label_folder])} successfully generated.')
+if os.path.exists(pos+label_folder):
+    rmtree(pos+label_folder)
+os.mkdir(pos+label_folder)
+print(f'Folder {pos+label_folder} successfully generated.')
 
 log=f'segmentation model: {modelname}\n'
 with open(pos+f'log_{mode}.json', 'a') as f:
     f.write(f'{datetime.datetime.now()} SEGMENT \n')
@@ -203,7 +203,7 @@ def segment_index(indices):
         if Y_pred.shape != template.shape[:2]:
             Y_pred = resize(Y_pred, template.shape[:2], order=0)
 
-        save_tiff_imagej_compatible(os.sep.join([pos,label_folder,f"{str(t).zfill(4)}.tif"]), Y_pred, axes='YX')
+        save_tiff_imagej_compatible(pos+os.sep.join([label_folder,f"{str(t).zfill(4)}.tif"]), Y_pred, axes='YX')
 
         del f;
         del template;
celldetective/scripts/track_cells.py

@@ -9,7 +9,7 @@ import json
 from celldetective.io import auto_load_number_of_frames, load_frames, interpret_tracking_configuration
 from celldetective.utils import extract_experiment_channels, _extract_channel_indices_from_config, _extract_channel_indices, ConfigSectionMap, _extract_nbr_channels_from_config, _get_img_num_per_channel, extract_experiment_channels
 from celldetective.measure import drop_tonal_features, measure_features
-from celldetective import track
+from celldetective.tracking import track
 from pathlib import Path, PurePath
 from glob import glob
 from shutil import rmtree
@@ -213,5 +213,8 @@ print(f"napari data successfully saved in {pos+os.sep.join(['output', 'tables'])
 trajectories.to_csv(pos+os.sep.join(['output', 'tables', table_name]), index=False)
 print(f"Table {table_name} successfully saved in {os.sep.join(['output', 'tables'])}")
 
+if os.path.exists(pos+os.sep.join(['output', 'tables', table_name.replace('.csv','.pkl')])):
+    os.remove(pos+os.sep.join(['output', 'tables', table_name.replace('.csv','.pkl')]))
+
 del trajectories; del napari_data;
 gc.collect()
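
Note: the added lines delete the .pkl companion of the trajectory table whenever the .csv is rewritten, presumably so a stale cached copy cannot outlive a fresh tracking run. A generic sketch of the invalidation pattern (file names are placeholders):

import os

csv_path = os.sep.join(['output', 'tables', 'trajectories.csv'])  # placeholder path
pkl_path = csv_path.replace('.csv', '.pkl')
if os.path.exists(pkl_path):
    os.remove(pkl_path)  # drop the stale sidecar; only the fresh csv remains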
celldetective/scripts/train_segmentation_model.py

@@ -16,7 +16,8 @@ from celldetective.io import normalize_multichannel
 from stardist import fill_label_holes
 from art import tprint
 import matplotlib.pyplot as plt
-
+from distutils.dir_util import copy_tree
+from csbdeep.utils import save_json
 
 tprint("Train")
 
@@ -169,22 +170,17 @@ elif model_type=='stardist':
 
     # Predict on subsampled grid for increased efficiency and larger field of view
     grid = (2,2)
-    conf = Config2D (
+    conf = Config2D(
         n_rays = n_rays,
         grid = grid,
         use_gpu = use_gpu,
         n_channel_in = n_channel,
-        unet_dropout = 0.0,
-        unet_batch_norm = False,
-        unet_n_conv_per_depth=2,
         train_learning_rate = learning_rate,
         train_patch_size = (256,256),
         train_epochs = epochs,
-        #train_foreground_only=0.9,
-        train_loss_weights=(1,0.2),
         train_reduce_lr = {'factor': 0.1, 'patience': 30, 'min_delta': 0},
-        unet_n_depth = 3,
         train_batch_size = batch_size,
+        train_steps_per_epoch = int(augmentation_factor*len(X_trn)),
     )
 
     if use_gpu:
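
Note: the trimmed call above drops the explicit U-Net overrides (depth, dropout, batch norm, loss weights) in favor of StarDist defaults, and scales train_steps_per_epoch with the augmentation factor and training-set size. A standalone sketch with placeholder hyperparameters, assuming stardist is installed:

from stardist.models import Config2D

n_rays, use_gpu, n_channel = 32, False, 1          # placeholder values
learning_rate, epochs, batch_size = 1e-4, 100, 4
augmentation_factor, n_train = 2.0, 50             # n_train stands in for len(X_trn)

conf = Config2D(
    n_rays = n_rays,
    grid = (2,2),                                  # predict on a subsampled grid
    use_gpu = use_gpu,
    n_channel_in = n_channel,
    train_learning_rate = learning_rate,
    train_patch_size = (256,256),
    train_epochs = epochs,
    train_reduce_lr = {'factor': 0.1, 'patience': 30, 'min_delta': 0},
    train_batch_size = batch_size,
    train_steps_per_epoch = int(augmentation_factor*n_train),
)
print(conf.train_steps_per_epoch)  # 100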
@@ -194,19 +190,28 @@ elif model_type=='stardist':
     if pretrained is None:
         model = StarDist2D(conf, name=model_name, basedir=target_directory)
     else:
-        # files_to_copy = glob(os.sep.join([pretrained, '*']))
-        # for f in files_to_copy:
-        #     shutil.copy(f, os.sep.join([target_directory, model_name, os.path.split(f)[-1]]))
-        idx=1
-        while os.path.exists(os.sep.join([target_directory, model_name])):
-            model_name = model_name+f'_{idx}'
-            idx+=1
-
-        shutil.copytree(pretrained, os.sep.join([target_directory, model_name]))
+
+        os.rename(instructions, os.sep.join([target_directory, model_name, 'temp.json']))
+        copy_tree(pretrained, os.sep.join([target_directory, model_name]))
+
+        if os.path.exists(os.sep.join([target_directory, model_name, 'training_instructions.json'])):
+            os.remove(os.sep.join([target_directory, model_name, 'training_instructions.json']))
+        if os.path.exists(os.sep.join([target_directory, model_name, 'config_input.json'])):
+            os.remove(os.sep.join([target_directory, model_name, 'config_input.json']))
+        if os.path.exists(os.sep.join([target_directory, model_name, 'logs'+os.sep])):
+            shutil.rmtree(os.sep.join([target_directory, model_name, 'logs']))
+        os.rename(os.sep.join([target_directory, model_name, 'temp.json']),os.sep.join([target_directory, model_name, 'training_instructions.json']))
+
+        #shutil.copytree(pretrained, os.sep.join([target_directory, model_name]))
         model = StarDist2D(None, name=model_name, basedir=target_directory)
         model.config.train_epochs = epochs
         model.config.train_batch_size = min(len(X_trn),batch_size)
-        model.config.train_learning_rate = learning_rate
+        model.config.train_learning_rate = learning_rate # perf seems bad if lr is changed in transfer
+        model.config.use_gpu = use_gpu
+        model.config.train_reduce_lr = {'factor': 0.1, 'patience': 10, 'min_delta': 0}
+        print(f'{model.config=}')
+
+        save_json(vars(model.config), os.sep.join([target_directory, model_name, 'config.json']))
 
     median_size = calculate_extents(list(Y_trn), np.mean)
     fov = np.array(model._axes_tile_overlap('YX'))
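
Note: in the transfer branch above, the pretrained folder is copied over the new model directory, stale metadata is purged, and the model is reloaded with config=None, which makes StarDist2D read the copied config.json; the edited config is then written back with csbdeep's save_json. A condensed sketch, with placeholder paths and a guard so it only runs against an existing model folder:

import os
from csbdeep.utils import save_json
from stardist.models import StarDist2D

target_directory, model_name = 'models', 'my_finetuned_model'  # placeholders
config_path = os.path.join(target_directory, model_name, 'config.json')
if os.path.exists(config_path):
    model = StarDist2D(None, name=model_name, basedir=target_directory)  # loads the copied config.json
    model.config.train_epochs = 50
    model.config.train_batch_size = 4
    model.config.train_reduce_lr = {'factor': 0.1, 'patience': 10, 'min_delta': 0}
    # persist the merged config so the copied folder matches the new schedule
    save_json(vars(model.config), config_path)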
celldetective/scripts/train_signal_model.py

@@ -12,6 +12,7 @@ import numpy as np
 import gc
 from art import tprint
 from celldetective.signals import SignalDetectionModel
+from celldetective.io import locate_signal_model
 
 tprint("Train")
 
@@ -36,14 +37,46 @@ else:
     print('The configuration path is not valid. Abort.')
     os.abort()
 
+all_classes = []
+for d in threshold_instructions["ds"]:
+    datasets = glob(d+os.sep+"*.npy")
+    for dd in datasets:
+        data = np.load(dd, allow_pickle=True)
+        classes = np.unique([ddd["class"] for ddd in data])
+        all_classes.extend(classes)
+all_classes = np.unique(all_classes)
+print(all_classes,len(all_classes))
+
+n_classes = len(all_classes)
 
 model_params = {k:threshold_instructions[k] for k in ('pretrained', 'model_signal_length', 'channel_option', 'n_channels', 'label') if k in threshold_instructions}
+model_params.update({'n_classes': n_classes})
+
 train_params = {k:threshold_instructions[k] for k in ('model_name', 'target_directory', 'channel_option','recompile_pretrained', 'test_split', 'augment', 'epochs', 'learning_rate', 'batch_size', 'validation_split','normalization_percentile','normalization_values','normalization_clip') if k in threshold_instructions}
 
 print(f'model params {model_params}')
 print(f'train params {train_params}')
 
 model = SignalDetectionModel(**model_params)
+print(threshold_instructions['ds'])
 model.fit_from_directory(threshold_instructions['ds'], **train_params)
 
+
+# if neighborhood of interest in training instructions, write it in config!
+if 'neighborhood_of_interest' in threshold_instructions:
+    if threshold_instructions['neighborhood_of_interest'] is not None:
+
+        model_path = locate_signal_model(threshold_instructions['model_name'], path=None, pairs=True)
+        complete_path = model_path #+model
+        complete_path = rf"{complete_path}"
+        model_config_path = os.sep.join([complete_path,'config_input.json'])
+        model_config_path = rf"{model_config_path}"
+
+        f = open(model_config_path)
+        config = json.load(f)
+        config.update({'neighborhood_of_interest': threshold_instructions['neighborhood_of_interest'], 'reference_population': threshold_instructions['reference_population'], 'neighbor_population': threshold_instructions['neighbor_population']})
+        json_string = json.dumps(config)
+        with open(model_config_path, 'w') as outfile:
+            outfile.write(json_string)
+
 print('Done.')
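
Note: the class scan added above infers n_classes directly from the annotation sets: each .npy file holds an array of per-cell dictionaries, and the union of their "class" entries sizes the classifier output. A standalone sketch (the directory list is a placeholder for threshold_instructions["ds"]):

import os
from glob import glob
import numpy as np

dataset_dirs = ['/path/to/signal_annotations']  # placeholder
all_classes = []
for d in dataset_dirs:
    for f in glob(d + os.sep + "*.npy"):
        data = np.load(f, allow_pickle=True)          # array of per-cell dicts
        all_classes.extend(np.unique([cell["class"] for cell in data]))

n_classes = len(np.unique(all_classes))  # passed to the model as n_classes
print(n_classes)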
celldetective/segmentation.py

@@ -20,6 +20,7 @@ from skimage.segmentation import watershed
 from skimage.feature import peak_local_max
 from skimage.measure import regionprops_table
 from skimage.exposure import match_histograms
+from scipy.ndimage import zoom
 import pandas as pd
 import subprocess
 
@@ -27,7 +28,7 @@ import subprocess
 
 abs_path = os.sep.join([os.path.split(os.path.dirname(os.path.realpath(__file__)))[0],'celldetective'])
 
 def segment(stack, model_name, channels=None, spatial_calibration=None, view_on_napari=False,
-            use_gpu=True, time_flat_normalization=False, time_flat_percentiles=(0.0,99.99)):
+            use_gpu=True, channel_axis=-1):
 
     """
 
@@ -85,7 +86,10 @@ def segment(stack, model_name, channels=None, spatial_calibration=None, view_on_
     if not use_gpu:
         os.environ['CUDA_VISIBLE_DEVICES'] = '-1'
     else:
-        os.environ['CUDA_VISIBLE_DEVICES'] = '0'
+        os.environ['CUDA_VISIBLE_DEVICES'] = '0'
+
+    if channel_axis != -1:
+        stack = np.moveaxis(stack, channel_axis, -1)
 
     if channels is not None:
         assert len(channels)==stack.shape[-1],f'The channel names provided do not match with the expected number of channels in the stack: {stack.shape[-1]}.'
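
Note: segment() now accepts a channel_axis argument and immediately rearranges the stack to channels-last, so the rest of the function can assume a (T, Y, X, C) layout. The move itself is a single NumPy call:

import numpy as np

stack = np.zeros((10, 2, 512, 512))  # hypothetical channels-first stack (T, C, Y, X)
channel_axis = 1
if channel_axis != -1:
    stack = np.moveaxis(stack, channel_axis, -1)
print(stack.shape)  # (10, 512, 512, 2)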
@@ -96,48 +100,83 @@ def segment(stack, model_name, channels=None, spatial_calibration=None, view_on_
     required_spatial_calibration = input_config['spatial_calibration']
     model_type = input_config['model_type']
 
-    if 'normalize' in input_config:
-        normalize = input_config['normalize']
-    else:
-        normalize = True
+    normalization_percentile = input_config['normalization_percentile']
+    normalization_clip = input_config['normalization_clip']
+    normalization_values = input_config['normalization_values']
 
     if model_type=='cellpose':
         diameter = input_config['diameter']
-        if diameter!=30:
-            required_spatial_calibration = None
+        # if diameter!=30:
+        #     required_spatial_calibration = None
        cellprob_threshold = input_config['cellprob_threshold']
        flow_threshold = input_config['flow_threshold']
 
     scale = _estimate_scale_factor(spatial_calibration, required_spatial_calibration)
 
     if model_type=='stardist':
+
         model = StarDist2D(None, name=model_name, basedir=Path(model_path).parent)
-        print(f"StarDist model {model_name} successfully loaded")
+        model.config.use_gpu = use_gpu
+        model.use_gpu = use_gpu
+        print(f"StarDist model {model_name} successfully loaded.")
+        scale_model = scale
 
     elif model_type=='cellpose':
-        model = CellposeModel(gpu=use_gpu, pretrained_model=model_path+model_path.split('/')[-2], diam_mean=30.0)
+
+        import torch
+        if not use_gpu:
+            device = torch.device("cpu")
+        else:
+            device = torch.device("cuda")
+
+        model = CellposeModel(gpu=use_gpu, device=device, pretrained_model=model_path+model_path.split('/')[-2], model_type=None, nchan=len(required_channels))
         if scale is None:
             scale_model = model.diam_mean / model.diam_labels
         else:
             scale_model = scale * model.diam_mean / model.diam_labels
+        print(f"Diam mean: {model.diam_mean}; Diam labels: {model.diam_labels}; Final rescaling: {scale_model}...")
+        print(f'Cellpose model {model_name} successfully loaded.')
 
     labels = []
-    if (time_flat_normalization)*normalize:
-        normalization_values = get_stack_normalization_values(stack[:,:,:,channel_indices], percentiles=time_flat_percentiles)
-    else:
-        normalization_values = [None]*len(channel_indices)
 
     for t in tqdm(range(len(stack)),desc="frame"):
 
         # normalize
-        frame = stack[t,:,:,np.array(channel_indices)]
-        if np.argmin(frame.shape)!=(frame.ndim-1):
-            frame = np.moveaxis(frame,np.argmin(frame.shape),-1)
-        if normalize:
-            frame = normalize_multichannel(frame, values=normalization_values)
-
-        if scale is not None:
-            frame = [ndi.zoom(frame[:,:,c].copy(), [scale_model,scale_model], order=3, prefilter=False) for c in range(frame.shape[-1])]
+        channel_indices = np.array(channel_indices)
+        none_channel_indices = np.where(channel_indices==None)[0]
+        channel_indices[channel_indices==None] = 0
+        print(channel_indices)
+
+        frame = stack[t,:,:,channel_indices.astype(int)].astype(float)
+        if frame.ndim==2:
+            frame = frame[:,:,np.newaxis]
+        if frame.ndim==3 and np.array(frame.shape).argmin()==0:
+            frame = np.moveaxis(frame,0,-1)
+        template = frame.copy()
+
+        values = []
+        percentiles = []
+        for k in range(len(normalization_percentile)):
+            if normalization_percentile[k]:
+                percentiles.append(normalization_values[k])
+                values.append(None)
+            else:
+                percentiles.append(None)
+                values.append(normalization_values[k])
+
+        frame = normalize_multichannel(frame, **{"percentiles": percentiles, 'values': values, 'clip': normalization_clip})
+
+        if scale_model is not None:
+            frame = [zoom(frame[:,:,c].copy(), [scale_model,scale_model], order=3, prefilter=False) for c in range(frame.shape[-1])]
+            frame = np.moveaxis(frame,0,-1)
+
+        for k in range(frame.shape[2]):
+            unique_values = np.unique(frame[:,:,k])
+            if len(unique_values)==1:
+                frame[0,0,k] += 1
+
+        frame = np.moveaxis([interpolate_nan(frame[:,:,c].copy()) for c in range(frame.shape[-1])],0,-1)
+        frame[:,:,none_channel_indices] = 0.
 
 
         if model_type=="stardist":
@@ -145,16 +184,15 @@ def segment(stack, model_name, channels=None, spatial_calibration=None, view_on_
             Y_pred = Y_pred.astype(np.uint16)
 
         elif model_type=="cellpose":
-
-            Y_pred, _, _ = model.eval(frame, diameter = diameter, cellprob_threshold=cellprob_threshold, flow_threshold=flow_threshold, channels=None, normalize=False)
+
+            img = np.moveaxis(frame, -1, 0)
+            Y_pred, _, _ = model.eval(img, diameter = diameter, cellprob_threshold=cellprob_threshold, flow_threshold=flow_threshold, channels=None, normalize=False)
             Y_pred = Y_pred.astype(np.uint16)
 
-        if scale is not None:
-            Y_pred = ndi.zoom(Y_pred, [1./scale_model,1./scale_model],order=0)
-
-
         if Y_pred.shape != stack[0].shape[:2]:
-            Y_pred = resize(Y_pred, stack[0].shape, order=0)
+            Y_pred = zoom(Y_pred, [1./scale_model,1./scale_model],order=0)
+        if Y_pred.shape != template.shape[:2]:
+            Y_pred = resize(Y_pred, template.shape[:2], order=0)
 
         labels.append(Y_pred)
 
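
Note: the rescaling logic above changed: instead of always inverting the zoom whenever a scale was set, the predicted mask is inverse-zoomed only if its shape no longer matches the original frame, with order=0 so label IDs stay integral rather than being interpolated. A minimal sketch with a hypothetical scale factor:

import numpy as np
from scipy.ndimage import zoom

scale_model = 0.5
Y_pred = np.random.randint(0, 5, size=(256, 256)).astype(np.uint16)   # toy label mask
if Y_pred.shape != (512, 512):                                        # target frame shape
    Y_pred = zoom(Y_pred, [1./scale_model, 1./scale_model], order=0)  # nearest-neighbour
print(Y_pred.shape)  # (512, 512)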