celldetective 1.3.6.post2__py3-none-any.whl → 1.3.7__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (38)
  1. celldetective/_version.py +1 -1
  2. celldetective/events.py +4 -0
  3. celldetective/gui/InitWindow.py +23 -9
  4. celldetective/gui/control_panel.py +19 -11
  5. celldetective/gui/generic_signal_plot.py +5 -0
  6. celldetective/gui/help/DL-segmentation-strategy.json +17 -17
  7. celldetective/gui/help/Threshold-vs-DL.json +11 -11
  8. celldetective/gui/help/cell-populations.json +5 -5
  9. celldetective/gui/help/exp-structure.json +15 -15
  10. celldetective/gui/help/feature-btrack.json +5 -5
  11. celldetective/gui/help/neighborhood.json +7 -7
  12. celldetective/gui/help/prefilter-for-segmentation.json +7 -7
  13. celldetective/gui/help/preprocessing.json +19 -19
  14. celldetective/gui/help/propagate-classification.json +7 -7
  15. celldetective/gui/plot_signals_ui.py +13 -9
  16. celldetective/gui/process_block.py +63 -14
  17. celldetective/gui/retrain_segmentation_model_options.py +21 -8
  18. celldetective/gui/retrain_signal_model_options.py +12 -2
  19. celldetective/gui/signal_annotator.py +9 -0
  20. celldetective/gui/signal_annotator2.py +8 -0
  21. celldetective/gui/styles.py +1 -0
  22. celldetective/gui/tableUI.py +1 -1
  23. celldetective/gui/workers.py +136 -0
  24. celldetective/io.py +53 -27
  25. celldetective/measure.py +112 -14
  26. celldetective/scripts/measure_cells.py +10 -35
  27. celldetective/scripts/segment_cells.py +15 -62
  28. celldetective/scripts/segment_cells_thresholds.py +1 -2
  29. celldetective/scripts/track_cells.py +16 -19
  30. celldetective/segmentation.py +16 -62
  31. celldetective/signals.py +11 -7
  32. celldetective/utils.py +587 -67
  33. {celldetective-1.3.6.post2.dist-info → celldetective-1.3.7.dist-info}/METADATA +1 -1
  34. {celldetective-1.3.6.post2.dist-info → celldetective-1.3.7.dist-info}/RECORD +38 -37
  35. {celldetective-1.3.6.post2.dist-info → celldetective-1.3.7.dist-info}/LICENSE +0 -0
  36. {celldetective-1.3.6.post2.dist-info → celldetective-1.3.7.dist-info}/WHEEL +0 -0
  37. {celldetective-1.3.6.post2.dist-info → celldetective-1.3.7.dist-info}/entry_points.txt +0 -0
  38. {celldetective-1.3.6.post2.dist-info → celldetective-1.3.7.dist-info}/top_level.txt +0 -0
celldetective/measure.py CHANGED
@@ -311,7 +311,7 @@ def measure_features(img, label, features=['area', 'intensity_mean'], channels=N
 
     if isinstance(features, list):
        features = features.copy()
-
+
    if features is None:
        features = []
 
@@ -986,6 +986,69 @@ def blob_detection(image, label, diameter, threshold=0., channel_name=None, targ
 
    return detections
 
+
+# def blob_detectionv0(image, label, threshold, diameter):
+#     """
+#     Perform blob detection on an image based on labeled regions.
+
+#     Parameters:
+#     - image (numpy.ndarray): The input image data.
+#     - label (numpy.ndarray): An array specifying labeled regions in the image.
+#     - threshold (float): The threshold value for blob detection.
+#     - diameter (float): The expected diameter of blobs.
+
+#     Returns:
+#     - dict: A dictionary containing information about detected blobs.
+
+#     This function performs blob detection on an image based on labeled regions. It iterates over each labeled region
+#     and detects blobs within the region using the Difference of Gaussians (DoG) method. Detected blobs are filtered
+#     based on the specified threshold and expected diameter. The function returns a dictionary containing the number of
+#     detected blobs and their mean intensity for each labeled region.
+
+#     Example:
+#     >>> image = np.array([[1, 2, 3], [4, 5, 6], [7, 8, 9]])
+#     >>> label = np.array([[0, 1, 1], [2, 2, 0], [3, 3, 0]])
+#     >>> threshold = 0.1
+#     >>> diameter = 5.0
+#     >>> result = blob_detection(image, label, threshold, diameter)
+#     >>> print(result)
+#     {1: [1, 4.0], 2: [0, nan], 3: [0, nan]}
+
+#     Note:
+#     - Blobs are detected using the Difference of Gaussians (DoG) method.
+#     - Detected blobs are filtered based on the specified threshold and expected diameter.
+#     - The returned dictionary contains information about the number of detected blobs and their mean intensity
+#       for each labeled region.
+#     """
+#     blob_labels = {}
+#     dilated_image = ndimage.grey_dilation(label, footprint=disk(10))
+#     for mask_index in np.unique(label):
+#         if mask_index == 0:
+#             continue
+#         removed_background = image.copy()
+#         one_mask = label.copy()
+#         one_mask[np.where(label != mask_index)] = 0
+#         dilated_copy = dilated_image.copy()
+#         dilated_copy[np.where(dilated_image != mask_index)] = 0
+#         removed_background[np.where(dilated_copy == 0)] = 0
+#         min_sigma = (1 / (1 + math.sqrt(2))) * diameter
+#         max_sigma = math.sqrt(2) * min_sigma
+#         blobs = blob_dog(removed_background, threshold=threshold, min_sigma=min_sigma,
+#                          max_sigma=max_sigma)
+
+#         mask = np.array([one_mask[int(y), int(x)] != 0 for y, x, r in blobs])
+#         if not np.any(mask):
+#             continue
+#         blobs_filtered = blobs[mask]
+#         binary_blobs = np.zeros_like(label)
+#         for blob in blobs_filtered:
+#             y, x, r = blob
+#             rr, cc = dsk((y, x), r, shape=binary_blobs.shape)
+#             binary_blobs[rr, cc] = 1
+#         spot_intensity = regionprops_table(binary_blobs, removed_background, ['intensity_mean'])
+#         blob_labels[mask_index] = [blobs_filtered.shape[0], spot_intensity['intensity_mean'][0]]
+#     return blob_labels
+
 ### Classification ####
 
 def estimate_time(df, class_attr, model='step_function', class_of_interest=[2], r2_threshold=0.5):
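Editor's note: the commented-out `blob_detectionv0` kept above derives the DoG sigma band from the expected blob diameter. A quick sanity check of that arithmetic (not part of the package, just illustrating the relationship used in the dead code):

```python
import math

# Sigma range used by the legacy blob_detectionv0 above.
# min_sigma = d / (1 + sqrt(2)) and max_sigma = sqrt(2) * min_sigma,
# so min_sigma + max_sigma == diameter: the DoG band brackets the expected blob size.
diameter = 5.0
min_sigma = (1 / (1 + math.sqrt(2))) * diameter  # ~2.07
max_sigma = math.sqrt(2) * min_sigma             # ~2.93
assert abs((min_sigma + max_sigma) - diameter) < 1e-12
```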
@@ -1163,13 +1226,15 @@ def classify_transient_events(data, class_attr, pre_event=None):
        assert 'class_'+pre_event in cols,"Pre-event class does not seem to be a valid column in the DataFrame..."
 
    stat_col = class_attr.replace('class','status')
+   continuous_stat_col = stat_col.replace('status_','smooth_status_')
+   df[continuous_stat_col] = df[stat_col].copy()
 
    for tid,track in df.groupby(sort_cols):
 
        indices = track[class_attr].index
 
        if pre_event is not None:
-
+
            if track['class_'+pre_event].values[0]==1:
                df.loc[indices, class_attr] = np.nan
                df.loc[indices, stat_col] = np.nan
@@ -1180,7 +1245,8 @@ def classify_transient_events(data, class_attr, pre_event=None):
            indices_pre = track.loc[track['FRAME']<=t_pre_event,class_attr].index
            df.loc[indices_pre, stat_col] = np.nan # set to NaN all statuses before pre-event
            track.loc[track['FRAME']<=t_pre_event, stat_col] = np.nan
-
+           track.loc[track['FRAME']<=t_pre_event, continuous_stat_col] = np.nan
+
        status = track[stat_col].to_numpy()
        timeline = track['FRAME'].to_numpy()
        timeline_safe = timeline[status==status]
@@ -1189,24 +1255,35 @@ def classify_transient_events(data, class_attr, pre_event=None):
        peaks, _ = find_peaks(status_safe)
        widths, _, left, right = peak_widths(status_safe, peaks, rel_height=1)
        minimum_weight = 0
-
+
        if len(peaks)>0:
            idx = np.argmax(widths)
-           peak = peaks[idx]; width = widths[idx];
+           peak = peaks[idx]; width = widths[idx];
            if width >= minimum_weight:
                left = left[idx]; right = right[idx];
                left = timeline_safe[int(left)]; right = timeline_safe[int(right)];
-
+
                df.loc[indices, class_attr] = 0
-               df.loc[indices, class_attr.replace('class_','t_')] = left + (right - left)/2.0
+               t0 = left #take onset + (right - left)/2.0
+               df.loc[indices, class_attr.replace('class_','t_')] = t0
+               df.loc[track.loc[track[stat_col].isnull(),class_attr].index, continuous_stat_col] = np.nan
+               df.loc[track.loc[track['FRAME']<t0,class_attr].index, continuous_stat_col] = 0
+               df.loc[track.loc[track['FRAME']>=t0,class_attr].index, continuous_stat_col] = 1
            else:
                df.loc[indices, class_attr] = 1
-               df.loc[indices, class_attr.replace('class_','t_')] = -1
+               df.loc[indices, class_attr.replace('class_','t_')] = -1
+               df.loc[indices, continuous_stat_col] = 0
        else:
            df.loc[indices, class_attr] = 1
            df.loc[indices, class_attr.replace('class_','t_')] = -1
-
-
+           df.loc[indices, continuous_stat_col] = 0
+
+   # restate NaN for out of scope timepoints
+   df.loc[df[stat_col].isnull(),continuous_stat_col] = np.nan
+   if 'inst_'+stat_col in list(df.columns):
+       df = df.drop(columns=['inst_'+stat_col])
+   df = df.rename(columns={stat_col: 'inst_'+stat_col})
+   df = df.rename(columns={continuous_stat_col: stat_col})
    print("Classes: ",df.loc[df['FRAME']==0,class_attr].value_counts())
 
    return df
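Editor's note: the rewritten classifier now keeps two status signals per event. The raw per-frame status is preserved under `inst_status_<event>`, while a smoothed step (0 before the detected onset `t0`, 1 after) takes over the original `status_<event>` name. A minimal sketch of the resulting columns on a toy single-track example (hypothetical data; column names follow the diff above):

```python
import numpy as np
import pandas as pd

inst = [0, 0, 1, 1, 0, 1, 1]           # noisy per-frame status
t0 = 2                                  # onset = left edge of the widest peak (see diff)
frames = np.arange(len(inst))
smooth = (frames >= t0).astype(float)   # step signal written to the smooth column

df = pd.DataFrame({'FRAME': frames,
                   'status_event': inst,
                   'smooth_status_event': smooth})
# Final renaming pass, as in the diff: raw -> inst_*, smoothed -> status_*.
df = df.rename(columns={'status_event': 'inst_status_event'})
df = df.rename(columns={'smooth_status_event': 'status_event'})
print(df)
```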
@@ -1286,7 +1363,7 @@ def classify_irreversible_events(data, class_attr, r2_threshold=0.5, percentile_
            indices_pre_detection = track.loc[track['FRAME']<=t_firstdetection,class_attr].index
            track.loc[indices_pre_detection,stat_col] = 0.0
            df.loc[indices_pre_detection,stat_col] = 0.0
-
+
        # The non-NaN part of track (post pre-event)
        track_valid = track.dropna(subset=stat_col, inplace=False)
        status_values = track_valid[stat_col].to_numpy()
@@ -1300,7 +1377,7 @@ def classify_irreversible_events(data, class_attr, r2_threshold=0.5, percentile_
        else:
            # ambiguity, possible transition, use `unique_state` technique after
            df.loc[indices, class_attr] = 2
-
+
    print("Classes after initial pass: ",df.loc[df['FRAME']==0,class_attr].value_counts())
 
    df.loc[df[class_attr]!=2, class_attr.replace('class', 't')] = -1
@@ -1363,7 +1440,7 @@ def classify_unique_states(df, class_attr, percentile=50, pre_event=None):
        assert 'class_'+pre_event in cols,"Pre-event class does not seem to be a valid column in the DataFrame..."
 
    stat_col = class_attr.replace('class','status')
-
+
    for tid, track in df.groupby(sort_cols):
 
        indices = track[class_attr].index
@@ -1488,4 +1565,25 @@ def classify_tracks_from_query(df, event_name, query, irreversible_event=True, u
 
    df = interpret_track_classification(df, class_attr, irreversible_event=irreversible_event, unique_state=unique_state, r2_threshold=r2_threshold, percentile_recovery=percentile_recovery)
 
-   return df
+   return df
+
+def measure_radial_distance_to_center(df, volume, column_labels={'track': "TRACK_ID", 'time': 'FRAME', 'x': 'POSITION_X', 'y': 'POSITION_Y'}):
+
+   try:
+       df['radial_distance'] = np.sqrt((df[column_labels['x']] - volume[0] / 2) ** 2 + (df[column_labels['y']] - volume[1] / 2) ** 2)
+   except Exception as e:
+       print(f"{e=}")
+
+   return df
+
+def center_of_mass_to_abs_coordinates(df):
+
+   center_of_mass_x_cols = [c for c in list(df.columns) if c.endswith('centre_of_mass_x')]
+   center_of_mass_y_cols = [c for c in list(df.columns) if c.endswith('centre_of_mass_y')]
+   for c in center_of_mass_x_cols:
+       df.loc[:,c.replace('_x','_POSITION_X')] = df[c] + df['POSITION_X']
+   for c in center_of_mass_y_cols:
+       df.loc[:,c.replace('_y','_POSITION_Y')] = df[c] + df['POSITION_Y']
+   df = df.drop(columns = center_of_mass_x_cols+center_of_mass_y_cols)
+
+   return df
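Editor's note: a quick usage sketch of the two helpers added above, on a toy table. The import line matches what measure_cells.py now imports in this release; the data is hypothetical:

```python
import pandas as pd
from celldetective.measure import (center_of_mass_to_abs_coordinates,
                                   measure_radial_distance_to_center)

# Hypothetical measurement table with a per-channel centre-of-mass offset.
df = pd.DataFrame({
    'POSITION_X': [10.0, 20.0],
    'POSITION_Y': [5.0, 15.0],
    'ch1_centre_of_mass_x': [1.0, -2.0],
    'ch1_centre_of_mass_y': [0.5, 0.0],
})

# Adds ch1_centre_of_mass_POSITION_X/Y (offset + cell position) and
# drops the relative _x/_y columns.
df = center_of_mass_to_abs_coordinates(df)

# Adds 'radial_distance', the distance of each cell to the image centre.
df = measure_radial_distance_to_center(df, volume=(2048, 2048))
print(df)
```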
celldetective/scripts/measure_cells.py CHANGED
@@ -7,8 +7,8 @@ import os
 import json
 from celldetective.io import auto_load_number_of_frames, load_frames, fix_missing_labels, locate_labels, extract_position_name
 from celldetective.utils import extract_experiment_channels, ConfigSectionMap, _get_img_num_per_channel, extract_experiment_channels
-from celldetective.utils import remove_redundant_features, remove_trajectory_measurements
-from celldetective.measure import drop_tonal_features, measure_features, measure_isotropic_intensity
+from celldetective.utils import _remove_invalid_cols, remove_redundant_features, remove_trajectory_measurements, _extract_coordinates_from_features
+from celldetective.measure import drop_tonal_features, measure_features, measure_isotropic_intensity, center_of_mass_to_abs_coordinates, measure_radial_distance_to_center
 from pathlib import Path, PurePath
 from glob import glob
 from tqdm import tqdm
@@ -16,7 +16,6 @@ import numpy as np
 import pandas as pd
 from natsort import natsorted
 from art import tprint
-import threading
 import datetime
 
 tprint("Measure")
@@ -68,7 +67,7 @@ instr_path = PurePath(expfolder,Path(f"{instruction_file}"))
 print('Looking for measurement instruction file...')
 
 if os.path.exists(instr_path):
-
+
    with open(instr_path, 'r') as f:
        instructions = json.load(f)
        print(f"Measurement instruction file successfully loaded...")
@@ -171,14 +170,6 @@ else:
        features += ['centroid']
    do_iso_intensities = False
 
-# if 'centroid' not in features:
-#     features += ['centroid']
-
-# if (features is not None) and (trajectories is not None):
-#     features = remove_redundant_features(features,
-#                                          trajectories.columns,
-#                                          channel_names=channel_names
-#                                          )
 
 len_movie_auto = auto_load_number_of_frames(file)
 if len_movie_auto is not None:
@@ -236,7 +227,7 @@ with open(pos + f'log_{mode}.json', 'a') as f:
 
 def measure_index(indices):
 
-   global column_labels
+   #global column_labels
 
    for t in tqdm(indices,desc="frame"):
 
@@ -258,10 +249,7 @@ def measure_index(indices):
                                    channels=channel_names, haralick_options=haralick_options, verbose=False,
                                    normalisation_list=background_correction, spot_detection=spot_detection)
        if trajectories is None:
-           positions_at_t = feature_table[['centroid-1', 'centroid-0', 'class_id']].copy()
-           positions_at_t['ID'] = np.arange(len(positions_at_t))  # temporary ID for the cells, that will be reset at the end since they are not tracked
-           positions_at_t.rename(columns={'centroid-1': 'POSITION_X', 'centroid-0': 'POSITION_Y'}, inplace=True)
-           positions_at_t['FRAME'] = int(t)
+           positions_at_t = _extract_coordinates_from_features(feature_table, timepoint=t)
            column_labels = {'track': "ID", 'time': column_labels['time'], 'x': column_labels['x'],
                             'y': column_labels['y']}
        feature_table.rename(columns={'centroid-1': 'POSITION_X', 'centroid-0': 'POSITION_Y'}, inplace=True)
@@ -278,25 +266,14 @@ def measure_index(indices):
        measurements_at_t = positions_at_t.merge(feature_table, how='outer', on='class_id',suffixes=('_delme', ''))
        measurements_at_t = measurements_at_t[[c for c in measurements_at_t.columns if not c.endswith('_delme')]]
 
-       center_of_mass_x_cols = [c for c in list(measurements_at_t.columns) if c.endswith('centre_of_mass_x')]
-       center_of_mass_y_cols = [c for c in list(measurements_at_t.columns) if c.endswith('centre_of_mass_y')]
-       for c in center_of_mass_x_cols:
-           measurements_at_t.loc[:,c.replace('_x','_POSITION_X')] = measurements_at_t[c] + measurements_at_t['POSITION_X']
-       for c in center_of_mass_y_cols:
-           measurements_at_t.loc[:,c.replace('_y','_POSITION_Y')] = measurements_at_t[c] + measurements_at_t['POSITION_Y']
-       measurements_at_t = measurements_at_t.drop(columns = center_of_mass_x_cols+center_of_mass_y_cols)
-
-       try:
-           measurements_at_t['radial_distance'] = np.sqrt((measurements_at_t[column_labels['x']] - img.shape[0] / 2) ** 2 + (
-                   measurements_at_t[column_labels['y']] - img.shape[1] / 2) ** 2)
-       except Exception as e:
-           print(f"{e=}")
+       measurements_at_t = center_of_mass_to_abs_coordinates(measurements_at_t)
+       measurements_at_t = measure_radial_distance_to_center(measurements_at_t, volume=img.shape, column_labels=column_labels)
 
        if measurements_at_t is not None:
            measurements_at_t[column_labels['time']] = t
            timestep_dataframes.append(measurements_at_t)
 
-   return
+   return
 
 
 print(f"Starting the measurements with {n_threads} thread(s)...")
@@ -327,12 +304,10 @@ if len(timestep_dataframes)>0:
        df = df.dropna(subset=[column_labels['track']])
    else:
        df['ID'] = np.arange(len(df))
+       df = df.sort_values(by=[column_labels['time'], 'ID'])
 
    df = df.reset_index(drop=True)
-
-   invalid_cols = [c for c in list(df.columns) if c.startswith('Unnamed')]
-   if len(invalid_cols)>0:
-       df = df.drop(invalid_cols, axis=1)
+   df = _remove_invalid_cols(df)
 
    df.to_csv(pos+os.sep.join(["output", "tables", table_name]), index=False)
    print(f'Measurement table successfully exported in {os.sep.join(["output", "tables"])}...')
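Editor's note: `_remove_invalid_cols` replaces the inline cleanup removed above. A sketch based on that removed code (the packaged helper in celldetective.utils may cover more cases):

```python
def _remove_invalid_cols(df):
    # Drop pandas round-trip artifacts such as 'Unnamed: 0' columns.
    invalid_cols = [c for c in list(df.columns) if c.startswith('Unnamed')]
    if len(invalid_cols) > 0:
        df = df.drop(invalid_cols, axis=1)
    return df
```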
celldetective/scripts/segment_cells.py CHANGED
@@ -6,21 +6,17 @@ import argparse
 import datetime
 import os
 import json
-from stardist.models import StarDist2D
-from cellpose.models import CellposeModel
-from celldetective.io import locate_segmentation_model, auto_load_number_of_frames, load_frames, extract_position_name
-from celldetective.utils import interpolate_nan, _estimate_scale_factor, _extract_channel_indices_from_config, ConfigSectionMap, _extract_nbr_channels_from_config, _get_img_num_per_channel
+from celldetective.io import locate_segmentation_model, auto_load_number_of_frames, extract_position_name, _load_frames_to_segment, _check_label_dims
+from celldetective.utils import _prep_stardist_model, _prep_cellpose_model, _rescale_labels, _segment_image_with_stardist_model,_segment_image_with_cellpose_model,_get_normalize_kwargs_from_config, _estimate_scale_factor, _extract_channel_indices_from_config, ConfigSectionMap, _extract_nbr_channels_from_config, _get_img_num_per_channel
 from pathlib import Path, PurePath
 from glob import glob
 from shutil import rmtree
 from tqdm import tqdm
 import numpy as np
-from skimage.transform import resize
 from csbdeep.io import save_tiff_imagej_compatible
 import gc
 from art import tprint
-from scipy.ndimage import zoom
-
+import concurrent.futures
 
 tprint("Segment")
 
@@ -91,9 +87,7 @@ print(f'Required channels: {required_channels} located at channel indices {chann
 required_spatial_calibration = input_config['spatial_calibration']
 print(f'Spatial calibration expected by the model: {required_spatial_calibration}...')
 
-normalization_percentile = input_config['normalization_percentile']
-normalization_clip = input_config['normalization_clip']
-normalization_values = input_config['normalization_values']
+normalize_kwargs = _get_normalize_kwargs_from_config(input_config)
 
 model_type = input_config['model_type']
 
@@ -142,81 +136,40 @@ with open(pos+f'log_{mode}.json', 'a') as f:
 # Loop over all frames and segment
 def segment_index(indices):
 
-   global scale
-
    if model_type=='stardist':
-       model = StarDist2D(None, name=modelname, basedir=Path(model_complete_path).parent)
-       model.config.use_gpu = use_gpu
-       model.use_gpu = use_gpu
-       print(f"StarDist model {modelname} successfully loaded.")
-       scale_model = scale
+       model, scale_model = _prep_stardist_model(modelname, Path(model_complete_path).parent, use_gpu=use_gpu, scale=scale)
 
    elif model_type=='cellpose':
-
-       import torch
-       if not use_gpu:
-           device = torch.device("cpu")
-       else:
-           device = torch.device("cuda")
-
-       model = CellposeModel(gpu=use_gpu, device=device, pretrained_model=model_complete_path+modelname, model_type=None, nchan=len(required_channels)) #diam_mean=30.0,
-       if scale is None:
-           scale_model = model.diam_mean / model.diam_labels
-       else:
-           scale_model = scale * model.diam_mean / model.diam_labels
-       print(f"Diam mean: {model.diam_mean}; Diam labels: {model.diam_labels}; Final rescaling: {scale_model}...")
-       print(f'Cellpose model {modelname} successfully loaded.')
+       model, scale_model = _prep_cellpose_model(modelname, model_complete_path, use_gpu=use_gpu, n_channels=len(required_channels), scale=scale)
 
    for t in tqdm(indices,desc="frame"):
-
-       # Load channels at time t
-       values = []
-       percentiles = []
-       for k in range(len(normalization_percentile)):
-           if normalization_percentile[k]:
-               percentiles.append(normalization_values[k])
-               values.append(None)
-           else:
-               percentiles.append(None)
-               values.append(normalization_values[k])
-
-       f = load_frames(img_num_channels[:,t], file, scale=scale_model, normalize_input=True, normalize_kwargs={"percentiles": percentiles, 'values': values, 'clip': normalization_clip})
-       f = np.moveaxis([interpolate_nan(f[:,:,c].copy()) for c in range(f.shape[-1])],0,-1)
-
-       if np.any(img_num_channels[:,t]==-1):
-           f[:,:,np.where(img_num_channels[:,t]==-1)[0]] = 0.
 
-       if model_type=="stardist":
-           Y_pred, details = model.predict_instances(f, n_tiles=model._guess_n_tiles(f), show_tile_progress=False, verbose=False)
-           Y_pred = Y_pred.astype(np.uint16)
+       f = _load_frames_to_segment(file, img_num_channels[:,t], scale_model=scale_model, normalize_kwargs=normalize_kwargs)
 
+       if model_type=="stardist":
+           Y_pred = _segment_image_with_stardist_model(f, model=model, return_details=False)
        elif model_type=="cellpose":
-
-           img = np.moveaxis(f, -1, 0)
-           Y_pred, _, _ = model.eval(img, diameter = diameter, cellprob_threshold=cellprob_threshold, flow_threshold=flow_threshold, channels=None, normalize=False)
-           Y_pred = Y_pred.astype(np.uint16)
+           Y_pred = _segment_image_with_cellpose_model(f, model=model, diameter=diameter, cellprob_threshold=cellprob_threshold, flow_threshold=flow_threshold)
 
        if scale is not None:
-           Y_pred = zoom(Y_pred, [1./scale_model,1./scale_model],order=0)
+           Y_pred = _rescale_labels(Y_pred, scale_model=scale_model)
 
-       template = load_frames(0,file,scale=1,normalize_input=False)
-       if Y_pred.shape != template.shape[:2]:
-           Y_pred = resize(Y_pred, template.shape[:2], order=0)
+       Y_pred = _check_label_dims(Y_pred, file)
 
        save_tiff_imagej_compatible(pos+os.sep.join([label_folder,f"{str(t).zfill(4)}.tif"]), Y_pred, axes='YX')
 
        del f;
-       del template;
        del Y_pred;
        gc.collect()
 
+   del model;
+   gc.collect()
+
    return
 
 
 print(f"Starting the segmentation with {n_threads} thread(s) and GPU={use_gpu}...")
 
-import concurrent.futures
-
 # Multithreading
 indices = list(range(img_num_channels.shape[1]))
 chunks = np.array_split(indices, n_threads)
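Editor's note: the model setup is now factored into `_prep_stardist_model` / `_prep_cellpose_model`. Sketches based on the inline code removed above (the packaged helpers live in celldetective.utils and may differ; requires stardist, cellpose and torch):

```python
from pathlib import Path
import torch
from stardist.models import StarDist2D
from cellpose.models import CellposeModel

def _prep_stardist_model(model_name, model_dir, use_gpu=False, scale=None):
    # StarDist keeps the user-requested scale unchanged.
    model = StarDist2D(None, name=model_name, basedir=model_dir)
    model.config.use_gpu = use_gpu
    model.use_gpu = use_gpu
    print(f"StarDist model {model_name} successfully loaded.")
    return model, scale

def _prep_cellpose_model(model_name, model_path, use_gpu=False, n_channels=1, scale=None):
    # Cellpose folds the model's training diameter into the rescaling factor.
    device = torch.device("cuda") if use_gpu else torch.device("cpu")
    model = CellposeModel(gpu=use_gpu, device=device,
                          pretrained_model=model_path + model_name,
                          model_type=None, nchan=n_channels)
    scale_model = model.diam_mean / model.diam_labels
    if scale is not None:
        scale_model *= scale
    print(f"Cellpose model {model_name} successfully loaded.")
    return model, scale_model
```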
celldetective/scripts/segment_cells_thresholds.py CHANGED
@@ -16,6 +16,7 @@ import numpy as np
 from csbdeep.io import save_tiff_imagej_compatible
 import gc
 from art import tprint
+import concurrent.futures
 
 tprint("Segment")
 
@@ -124,8 +125,6 @@ def segment_index(indices):
 
 print(f"Starting the segmentation with {n_threads} thread(s)...")
 
-import concurrent.futures
-
 # Multithreading
 indices = list(range(img_num_channels.shape[1]))
 chunks = np.array_split(indices, n_threads)
celldetective/scripts/track_cells.py CHANGED
@@ -6,8 +6,8 @@ import argparse
 import datetime
 import os
 import json
-from celldetective.io import auto_load_number_of_frames, load_frames, interpret_tracking_configuration, extract_position_name
-from celldetective.utils import extract_experiment_channels, ConfigSectionMap, _get_img_num_per_channel, extract_experiment_channels
+from celldetective.io import auto_load_number_of_frames, interpret_tracking_configuration, extract_position_name
+from celldetective.utils import _mask_intensity_measurements, extract_experiment_channels, ConfigSectionMap, _get_img_num_per_channel, extract_experiment_channels
 from celldetective.measure import drop_tonal_features, measure_features
 from celldetective.tracking import track
 from pathlib import Path, PurePath
@@ -19,7 +19,8 @@ import gc
 import os
 from natsort import natsorted
 from art import tprint
-from tifffile import imread
+import concurrent.futures
+
 
 tprint("Track")
 
@@ -173,35 +174,39 @@ if not btrack_option:
 
    def measure_index(indices):
 
+       props = []
+
        for t in tqdm(indices,desc="frame"):
 
            # Load channels at time t
-           img = load_frames(img_num_channels[:,t], file, scale=None, normalize_input=False)
-           lbl = imread(label_path[t])
+           img = _load_frames_to_measure(file, indices=img_num_channels[:,t])
+           lbl = locate_labels(pos, population=mode, frames=t)
+           if lbl is None:
+               continue
 
            df_props = measure_features(img, lbl, features = features+['centroid'], border_dist=None,
                                        channels=channel_names, haralick_options=haralick_options, verbose=False,
                                        )
            df_props.rename(columns={'centroid-1': 'x', 'centroid-0': 'y'},inplace=True)
            df_props['t'] = int(t)
-           timestep_dataframes.append(df_props)
-       return
 
+           props.append(df_props)
 
+       return props
 
    print(f"Measuring features with {n_threads} thread(s)...")
 
-   import concurrent.futures
-
    # Multithreading
    indices = list(range(img_num_channels.shape[1]))
    chunks = np.array_split(indices, n_threads)
 
+   timestep_dataframes = []
    with concurrent.futures.ThreadPoolExecutor() as executor:
        results = executor.map(measure_index, chunks)
        try:
            for i,return_value in enumerate(results):
-               print(f"Thread {i} output check: ",return_value)
+               print(f"Thread {i} completed...")
+               timestep_dataframes.extend(return_value)
        except Exception as e:
            print("Exception: ", e)
 
@@ -210,15 +215,7 @@ print('Features successfully measured...')
    df = pd.concat(timestep_dataframes)
    df.reset_index(inplace=True, drop=True)
 
-   if mask_channels is not None:
-       cols_to_drop = []
-       for mc in mask_channels:
-           columns = df.columns
-           col_contains = [mc in c for c in columns]
-           to_remove = np.array(columns)[np.array(col_contains)]
-           cols_to_drop.extend(to_remove)
-       if len(cols_to_drop)>0:
-           df = df.drop(cols_to_drop, axis=1)
+   df = _mask_intensity_measurements(df, mask_channels)
 
 # do tracking
 if btrack_option:
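Editor's note: `_mask_intensity_measurements` replaces the inline column filtering removed above. A sketch based on that removed code (the packaged helper is in celldetective.utils):

```python
import numpy as np

def _mask_intensity_measurements(df, mask_channels):
    # Drop every measurement column whose name contains a masked channel name.
    if mask_channels is not None:
        cols_to_drop = []
        for mc in mask_channels:
            columns = df.columns
            col_contains = [mc in c for c in columns]
            cols_to_drop.extend(np.array(columns)[np.array(col_contains)])
        if len(cols_to_drop) > 0:
            df = df.drop(cols_to_drop, axis=1)
    return df
```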
celldetective/segmentation.py CHANGED
@@ -8,12 +8,9 @@ from .utils import _estimate_scale_factor, _extract_channel_indices
 from pathlib import Path
 from tqdm import tqdm
 import numpy as np
-from stardist.models import StarDist2D
-from cellpose.models import CellposeModel
-from skimage.transform import resize
-from celldetective.io import _view_on_napari, locate_labels, locate_stack, _view_on_napari
+from celldetective.io import _view_on_napari, locate_labels, locate_stack, _view_on_napari, _check_label_dims
 from celldetective.filters import * #rework this to give a name
-from celldetective.utils import rename_intensity_column, mask_edges, estimate_unreliable_edge
+from celldetective.utils import interpolate_nan_multichannel,_rearrange_multichannel_frame, _fix_no_contrast, zoom_multiframes, _rescale_labels, rename_intensity_column, mask_edges, _prep_stardist_model, _prep_cellpose_model, estimate_unreliable_edge,_get_normalize_kwargs_from_config, _segment_image_with_stardist_model, _segment_image_with_cellpose_model
 from stardist import fill_label_holes
 import scipy.ndimage as ndi
 from skimage.segmentation import watershed
@@ -99,9 +96,7 @@ def segment(stack, model_name, channels=None, spatial_calibration=None, view_on_
    required_spatial_calibration = input_config['spatial_calibration']
    model_type = input_config['model_type']
 
-   normalization_percentile = input_config['normalization_percentile']
-   normalization_clip = input_config['normalization_clip']
-   normalization_values = input_config['normalization_values']
+   normalize_kwargs = _get_normalize_kwargs_from_config(input_config)
 
    if model_type=='cellpose':
        diameter = input_config['diameter']
@@ -116,28 +111,10 @@ def segment(stack, model_name, channels=None, spatial_calibration=None, view_on_
    print(f"{spatial_calibration=} {required_spatial_calibration=} Scale = {scale}...")
 
    if model_type=='stardist':
-
-       model = StarDist2D(None, name=model_name, basedir=Path(model_path).parent)
-       model.config.use_gpu = use_gpu
-       model.use_gpu = use_gpu
-       print(f"StarDist model {model_name} successfully loaded.")
-       scale_model = scale
+       model, scale_model = _prep_stardist_model(model_name, Path(model_path).parent, use_gpu=use_gpu, scale=scale)
 
    elif model_type=='cellpose':
-
-       import torch
-       if not use_gpu:
-           device = torch.device("cpu")
-       else:
-           device = torch.device("cuda")
-
-       model = CellposeModel(gpu=use_gpu, device=device, pretrained_model=model_path+model_path.split('/')[-2], model_type=None, nchan=len(required_channels))
-       if scale is None:
-           scale_model = model.diam_mean / model.diam_labels
-       else:
-           scale_model = scale * model.diam_mean / model.diam_labels
-       print(f"Diam mean: {model.diam_mean}; Diam labels: {model.diam_labels}; Final rescaling: {scale_model}...")
-       print(f'Cellpose model {model_name} successfully loaded.')
+       model, scale_model = _prep_cellpose_model(model_path.split('/')[-2], model_path, use_gpu=use_gpu, n_channels=len(required_channels), scale=scale)
 
    labels = []
 
@@ -149,11 +126,7 @@ def segment(stack, model_name, channels=None, spatial_calibration=None, view_on_
        channel_indices[channel_indices==None] = 0
 
        frame = stack[t]
-       #frame = stack[t,:,:,channel_indices.astype(int)].astype(float)
-       if frame.ndim==2:
-           frame = frame[:,:,np.newaxis]
-       if frame.ndim==3 and np.array(frame.shape).argmin()==0:
-           frame = np.moveaxis(frame,0,-1)
+       frame = _rearrange_multichannel_frame(frame).astype(float)
 
        frame_to_segment = np.zeros((frame.shape[0], frame.shape[1], len(required_channels))).astype(float)
        for ch in channel_intersection:
@@ -162,45 +135,25 @@ def segment(stack, model_name, channels=None, spatial_calibration=None, view_on_
        frame = frame_to_segment
        template = frame.copy()
 
-       values = []
-       percentiles = []
-       for k in range(len(normalization_percentile)):
-           if normalization_percentile[k]:
-               percentiles.append(normalization_values[k])
-               values.append(None)
-           else:
-               percentiles.append(None)
-               values.append(normalization_values[k])
-
-       frame = normalize_multichannel(frame, **{"percentiles": percentiles, 'values': values, 'clip': normalization_clip})
+       frame = normalize_multichannel(frame, **normalize_kwargs)
 
        if scale_model is not None:
-           frame = [zoom(frame[:,:,c].copy(), [scale_model,scale_model], order=3, prefilter=False) for c in range(frame.shape[-1])]
-           frame = np.moveaxis(frame,0,-1)
-
-       for k in range(frame.shape[2]):
-           unique_values = np.unique(frame[:,:,k])
-           if len(unique_values)==1:
-               frame[0,0,k] += 1
+           frame = zoom_multiframes(frame, scale_model)
 
-       frame = np.moveaxis([interpolate_nan(frame[:,:,c].copy()) for c in range(frame.shape[-1])],0,-1)
+       frame = _fix_no_contrast(frame)
+       frame = interpolate_nan_multichannel(frame)
        frame[:,:,none_channel_indices] = 0.
 
        if model_type=="stardist":
-
-           Y_pred, details = model.predict_instances(frame, n_tiles=model._guess_n_tiles(frame), show_tile_progress=False, verbose=False)
-           Y_pred = Y_pred.astype(np.uint16)
+           Y_pred = _segment_image_with_stardist_model(frame, model=model, return_details=False)
 
        elif model_type=="cellpose":
-
-           img = np.moveaxis(frame, -1, 0)
-           Y_pred, _, _ = model.eval(img, diameter = diameter, cellprob_threshold=cellprob_threshold, flow_threshold=flow_threshold, channels=None, normalize=False)
-           Y_pred = Y_pred.astype(np.uint16)
+           Y_pred = _segment_image_with_cellpose_model(frame, model=model, diameter=diameter, cellprob_threshold=cellprob_threshold, flow_threshold=flow_threshold)
 
        if Y_pred.shape != stack[0].shape[:2]:
-           Y_pred = zoom(Y_pred, [1./scale_model,1./scale_model],order=0)
-           if Y_pred.shape != template.shape[:2]:
-               Y_pred = resize(Y_pred, template.shape[:2], order=0)
+           Y_pred = _rescale_labels(Y_pred, scale_model)
+
+       Y_pred = _check_label_dims(Y_pred, template=template)
 
        labels.append(Y_pred)
 
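Editor's note: `_fix_no_contrast` replaces the inline loop removed above. A sketch based on that removed code (the packaged helper is in celldetective.utils):

```python
import numpy as np

def _fix_no_contrast(frame):
    # If a channel is perfectly flat (a single unique value), nudge one
    # pixel so downstream percentile normalization does not divide by zero.
    for k in range(frame.shape[2]):
        if len(np.unique(frame[:, :, k])) == 1:
            frame[0, 0, k] += 1
    return frame
```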
@@ -315,6 +268,7 @@ def segment_frame_from_thresholds(frame, target_channel=0, thresholds=None, equa
    """
 
    img = frame[:,:,target_channel]
+   img = interpolate_nan(img)
    if equalize_reference is not None:
        img = match_histograms(img, equalize_reference)
    img_mc = frame.copy()
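Editor's note: the added `interpolate_nan(img)` call guards histogram matching and thresholding against NaN pixels. The implementation of `interpolate_nan` is not shown in this diff; a hypothetical nearest-neighbour fill illustrating the idea (the packaged function in celldetective.utils may use a different scheme):

```python
import numpy as np
from scipy import ndimage

def interpolate_nan(img):
    # Hypothetical sketch: replace each NaN with the value of the nearest
    # valid pixel, found via a Euclidean distance transform of the NaN mask.
    if not np.any(np.isnan(img)):
        return img
    mask = np.isnan(img)
    _, idx = ndimage.distance_transform_edt(mask, return_indices=True)
    return img[tuple(idx)]
```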