celldetective 1.4.2__py3-none-any.whl → 1.5.0b1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (152) hide show
  1. celldetective/__init__.py +25 -0
  2. celldetective/__main__.py +62 -43
  3. celldetective/_version.py +1 -1
  4. celldetective/extra_properties.py +477 -399
  5. celldetective/filters.py +192 -97
  6. celldetective/gui/InitWindow.py +541 -411
  7. celldetective/gui/__init__.py +0 -15
  8. celldetective/gui/about.py +44 -39
  9. celldetective/gui/analyze_block.py +120 -84
  10. celldetective/gui/base/__init__.py +0 -0
  11. celldetective/gui/base/channel_norm_generator.py +335 -0
  12. celldetective/gui/base/components.py +249 -0
  13. celldetective/gui/base/feature_choice.py +92 -0
  14. celldetective/gui/base/figure_canvas.py +52 -0
  15. celldetective/gui/base/list_widget.py +133 -0
  16. celldetective/gui/{styles.py → base/styles.py} +92 -36
  17. celldetective/gui/base/utils.py +33 -0
  18. celldetective/gui/base_annotator.py +900 -767
  19. celldetective/gui/classifier_widget.py +6 -22
  20. celldetective/gui/configure_new_exp.py +777 -671
  21. celldetective/gui/control_panel.py +635 -524
  22. celldetective/gui/dynamic_progress.py +449 -0
  23. celldetective/gui/event_annotator.py +2023 -1662
  24. celldetective/gui/generic_signal_plot.py +1292 -944
  25. celldetective/gui/gui_utils.py +899 -1289
  26. celldetective/gui/interactions_block.py +658 -0
  27. celldetective/gui/interactive_timeseries_viewer.py +447 -0
  28. celldetective/gui/json_readers.py +48 -15
  29. celldetective/gui/layouts/__init__.py +5 -0
  30. celldetective/gui/layouts/background_model_free_layout.py +537 -0
  31. celldetective/gui/layouts/channel_offset_layout.py +134 -0
  32. celldetective/gui/layouts/local_correction_layout.py +91 -0
  33. celldetective/gui/layouts/model_fit_layout.py +372 -0
  34. celldetective/gui/layouts/operation_layout.py +68 -0
  35. celldetective/gui/layouts/protocol_designer_layout.py +96 -0
  36. celldetective/gui/pair_event_annotator.py +3130 -2435
  37. celldetective/gui/plot_measurements.py +586 -267
  38. celldetective/gui/plot_signals_ui.py +724 -506
  39. celldetective/gui/preprocessing_block.py +395 -0
  40. celldetective/gui/process_block.py +1678 -1831
  41. celldetective/gui/seg_model_loader.py +580 -473
  42. celldetective/gui/settings/__init__.py +0 -7
  43. celldetective/gui/settings/_cellpose_model_params.py +181 -0
  44. celldetective/gui/settings/_event_detection_model_params.py +95 -0
  45. celldetective/gui/settings/_segmentation_model_params.py +159 -0
  46. celldetective/gui/settings/_settings_base.py +77 -65
  47. celldetective/gui/settings/_settings_event_model_training.py +752 -526
  48. celldetective/gui/settings/_settings_measurements.py +1133 -964
  49. celldetective/gui/settings/_settings_neighborhood.py +574 -488
  50. celldetective/gui/settings/_settings_segmentation_model_training.py +779 -564
  51. celldetective/gui/settings/_settings_signal_annotator.py +329 -305
  52. celldetective/gui/settings/_settings_tracking.py +1304 -1094
  53. celldetective/gui/settings/_stardist_model_params.py +98 -0
  54. celldetective/gui/survival_ui.py +422 -312
  55. celldetective/gui/tableUI.py +1665 -1701
  56. celldetective/gui/table_ops/_maths.py +295 -0
  57. celldetective/gui/table_ops/_merge_groups.py +140 -0
  58. celldetective/gui/table_ops/_merge_one_hot.py +95 -0
  59. celldetective/gui/table_ops/_query_table.py +43 -0
  60. celldetective/gui/table_ops/_rename_col.py +44 -0
  61. celldetective/gui/thresholds_gui.py +382 -179
  62. celldetective/gui/viewers/__init__.py +0 -0
  63. celldetective/gui/viewers/base_viewer.py +700 -0
  64. celldetective/gui/viewers/channel_offset_viewer.py +331 -0
  65. celldetective/gui/viewers/contour_viewer.py +394 -0
  66. celldetective/gui/viewers/size_viewer.py +153 -0
  67. celldetective/gui/viewers/spot_detection_viewer.py +341 -0
  68. celldetective/gui/viewers/threshold_viewer.py +309 -0
  69. celldetective/gui/workers.py +403 -126
  70. celldetective/log_manager.py +92 -0
  71. celldetective/measure.py +1895 -1478
  72. celldetective/napari/__init__.py +0 -0
  73. celldetective/napari/utils.py +1025 -0
  74. celldetective/neighborhood.py +1914 -1448
  75. celldetective/preprocessing.py +1620 -1220
  76. celldetective/processes/__init__.py +0 -0
  77. celldetective/processes/background_correction.py +271 -0
  78. celldetective/processes/compute_neighborhood.py +894 -0
  79. celldetective/processes/detect_events.py +246 -0
  80. celldetective/processes/downloader.py +137 -0
  81. celldetective/processes/measure_cells.py +565 -0
  82. celldetective/processes/segment_cells.py +760 -0
  83. celldetective/processes/track_cells.py +435 -0
  84. celldetective/processes/train_segmentation_model.py +694 -0
  85. celldetective/processes/train_signal_model.py +265 -0
  86. celldetective/processes/unified_process.py +292 -0
  87. celldetective/regionprops/_regionprops.py +358 -317
  88. celldetective/relative_measurements.py +987 -710
  89. celldetective/scripts/measure_cells.py +313 -212
  90. celldetective/scripts/measure_relative.py +90 -46
  91. celldetective/scripts/segment_cells.py +165 -104
  92. celldetective/scripts/segment_cells_thresholds.py +96 -68
  93. celldetective/scripts/track_cells.py +198 -149
  94. celldetective/scripts/train_segmentation_model.py +324 -201
  95. celldetective/scripts/train_signal_model.py +87 -45
  96. celldetective/segmentation.py +844 -749
  97. celldetective/signals.py +3514 -2861
  98. celldetective/tracking.py +30 -15
  99. celldetective/utils/__init__.py +0 -0
  100. celldetective/utils/cellpose_utils/__init__.py +133 -0
  101. celldetective/utils/color_mappings.py +42 -0
  102. celldetective/utils/data_cleaning.py +630 -0
  103. celldetective/utils/data_loaders.py +450 -0
  104. celldetective/utils/dataset_helpers.py +207 -0
  105. celldetective/utils/downloaders.py +235 -0
  106. celldetective/utils/event_detection/__init__.py +8 -0
  107. celldetective/utils/experiment.py +1782 -0
  108. celldetective/utils/image_augmenters.py +308 -0
  109. celldetective/utils/image_cleaning.py +74 -0
  110. celldetective/utils/image_loaders.py +926 -0
  111. celldetective/utils/image_transforms.py +335 -0
  112. celldetective/utils/io.py +62 -0
  113. celldetective/utils/mask_cleaning.py +348 -0
  114. celldetective/utils/mask_transforms.py +5 -0
  115. celldetective/utils/masks.py +184 -0
  116. celldetective/utils/maths.py +351 -0
  117. celldetective/utils/model_getters.py +325 -0
  118. celldetective/utils/model_loaders.py +296 -0
  119. celldetective/utils/normalization.py +380 -0
  120. celldetective/utils/parsing.py +465 -0
  121. celldetective/utils/plots/__init__.py +0 -0
  122. celldetective/utils/plots/regression.py +53 -0
  123. celldetective/utils/resources.py +34 -0
  124. celldetective/utils/stardist_utils/__init__.py +104 -0
  125. celldetective/utils/stats.py +90 -0
  126. celldetective/utils/types.py +21 -0
  127. {celldetective-1.4.2.dist-info → celldetective-1.5.0b1.dist-info}/METADATA +1 -1
  128. celldetective-1.5.0b1.dist-info/RECORD +187 -0
  129. {celldetective-1.4.2.dist-info → celldetective-1.5.0b1.dist-info}/WHEEL +1 -1
  130. tests/gui/test_new_project.py +129 -117
  131. tests/gui/test_project.py +127 -79
  132. tests/test_filters.py +39 -15
  133. tests/test_notebooks.py +8 -0
  134. tests/test_tracking.py +232 -13
  135. tests/test_utils.py +123 -77
  136. celldetective/gui/base_components.py +0 -23
  137. celldetective/gui/layouts.py +0 -1602
  138. celldetective/gui/processes/compute_neighborhood.py +0 -594
  139. celldetective/gui/processes/downloader.py +0 -111
  140. celldetective/gui/processes/measure_cells.py +0 -360
  141. celldetective/gui/processes/segment_cells.py +0 -499
  142. celldetective/gui/processes/track_cells.py +0 -303
  143. celldetective/gui/processes/train_segmentation_model.py +0 -270
  144. celldetective/gui/processes/train_signal_model.py +0 -108
  145. celldetective/gui/table_ops/merge_groups.py +0 -118
  146. celldetective/gui/viewers.py +0 -1354
  147. celldetective/io.py +0 -3663
  148. celldetective/utils.py +0 -3108
  149. celldetective-1.4.2.dist-info/RECORD +0 -123
  150. {celldetective-1.4.2.dist-info → celldetective-1.5.0b1.dist-info}/entry_points.txt +0 -0
  151. {celldetective-1.4.2.dist-info → celldetective-1.5.0b1.dist-info}/licenses/LICENSE +0 -0
  152. {celldetective-1.4.2.dist-info → celldetective-1.5.0b1.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,926 @@
1
+ import gc
2
+ import json
3
+ import os
4
+ from glob import glob
5
+ from typing import Optional
6
+
7
+ import numpy as np
8
+ from celldetective.utils.io import save_tiff_imagej_compatible
9
+ from imageio import v2 as imageio
10
+ from natsort import natsorted
11
+ from tifffile import imread, TiffFile
12
+
13
+ from celldetective.utils.image_cleaning import (
14
+ _fix_no_contrast,
15
+ interpolate_nan_multichannel,
16
+ )
17
+ from celldetective.utils.normalization import normalize_multichannel
18
+ from celldetective import get_logger
19
+
20
+ import logging
21
+ import warnings
22
+
23
# Module-level logger used by the loaders defined in this file.
logger = get_logger(__name__)

# Silence the "MMStack series is missing files" noise emitted when reading
# stacks: filter the matching messages raised through the `warnings` module
# and let only ERROR (and above) records through the "tifffile" logger.
warnings.filterwarnings("ignore", message=".*MMStack series is missing files.*")
logging.getLogger("tifffile").setLevel(logging.ERROR)
28
+
29
+
30
def locate_stack(position, prefix="Aligned"):
    """
    Find the movie of a position and load it as a 4D, channel-last array.

    Parameters
    ----------
    position : str
        Path to the position folder. The movie is searched in its ``movie``
        subfolder.
    prefix : str, optional
        Beginning of the movie file name; the first file matching
        ``movie/<prefix>*.tif`` is loaded. The default is 'Aligned'.

    Returns
    -------
    stack : ndarray
        The loaded stack. 2D, 3D and 4D inputs are reshaped to four
        dimensions with the channel axis last; arrays of any other
        dimensionality are returned as read.

    Raises
    ------
    FileNotFoundError
        If no file matching ``movie/<prefix>*.tif`` exists in `position`.

    Notes
    -----
    - 4D stacks: axis 1 is moved to the last position.
    - 3D stacks: the number of frames is obtained with
      `auto_load_number_of_frames`. When the smallest dimension differs from
      that number, the array is treated as a single multichannel frame (the
      smallest axis is moved last and a leading axis of size 1 is added).
      Otherwise a trailing axis of size 1 is appended.
    - 2D stacks: a leading and a trailing axis of size 1 are added.

    Examples
    --------
    >>> stack = locate_stack(position, prefix='Aligned')
    # Locate and load a stack of images for further processing.

    """

    root = position if position.endswith(os.sep) else position + os.sep

    candidates = glob(root + os.sep.join(["movie", f"{prefix}*.tif"]))
    if len(candidates) == 0:
        raise FileNotFoundError(f"No movie with prefix {prefix} found...")

    movie_file = candidates[0]
    stack = imread(movie_file.replace("\\", "/"))
    stack_length = auto_load_number_of_frames(movie_file)

    n_dims = stack.ndim

    if n_dims == 2:
        # Single frame, single channel
        return stack[np.newaxis, :, :, np.newaxis]

    if n_dims == 4:
        return np.moveaxis(stack, 1, -1)

    if n_dims != 3:
        return stack

    if min(stack.shape) == stack_length:
        # The smallest axis is the time axis: only the channel axis is missing
        return stack[:, :, :, np.newaxis]

    # The smallest axis is taken as the channel axis of a single frame
    channel_axis = np.argmin(stack.shape)
    if channel_axis != (n_dims - 1):
        stack = np.moveaxis(stack, channel_axis, -1)
    return stack[np.newaxis, :, :, :]
90
+
91
+
92
def locate_labels(position, population="target", frames=None):
    """
    Locate and load label images for a given position and population.

    Parameters
    ----------
    position : str
        Path to the position directory containing the label folders.
    population : str, optional
        Population whose labels are loaded. `'target'`/`'targets'` map to the
        ``labels_targets`` folder and `'effector'`/`'effectors'` to
        ``labels_effectors`` (case-insensitive). Any other value is looked up
        in ``labels_<population>``. Default is `'target'`.
    frames : int, float, sequence of int, numpy.ndarray, or None, optional
        Specifies which frames to load:
        - `None`: load every ``*.tif`` file of the folder, in natural order
          (default).
        - scalar: load the single frame with that index.
        - `list`, `tuple` or `numpy.ndarray`: load each listed frame.

    Returns
    -------
    numpy.ndarray, list, or None
        - `frames` is `None`: an array stacking all label images.
        - `frames` is a scalar: the label image, or `None` if the file does
          not exist.
        - `frames` is a sequence: a list with one entry per requested frame,
          `None` standing for frames whose file does not exist.
        - any other type: a message is printed and `None` is returned.

    Notes
    -----
    - Frame indices are converted to ``int`` and zero-padded to four digits,
      then matched against file names of the form ``####.tif``
      (e.g. ``0001.tif``).

    Examples
    --------
    Load all label images for a position:

    >>> labels = locate_labels("/path/to/position", population="target")

    Load a single frame (frame index 3):

    >>> label = locate_labels("/path/to/position", population="effector", frames=3)

    Load multiple specific frames:

    >>> labels = locate_labels("/path/to/position", population="target", frames=[0, 1, 2])

    """

    if not position.endswith(os.sep):
        position += os.sep

    pop = population.lower()
    if pop in ("target", "targets"):
        folder = "labels_targets"
    elif pop in ("effector", "effectors"):
        folder = "labels_effectors"
    else:
        folder = f"labels_{population}"

    label_path = natsorted(glob(position + os.sep.join([folder, "*.tif"])))

    if frames is None:
        return np.array([imread(p.replace("\\", "/")) for p in label_path])

    # File name -> full path, so each requested frame is a single lookup
    path_by_name = {os.path.split(p)[-1]: p for p in label_path}

    def _read_frame(frame_index):
        """Return the label image of `frame_index`, or None if its file is absent."""
        file_name = f"{str(int(frame_index)).zfill(4)}.tif"
        found = path_by_name.get(file_name)
        if found is None:
            return None
        return np.array(imread(found.replace("\\", "/")))

    # np.integer / np.floating cover every NumPy scalar width, not only np.int_
    if isinstance(frames, (int, float, np.integer, np.floating)):
        return _read_frame(frames)

    if isinstance(frames, (list, tuple, np.ndarray)):
        return [_read_frame(f) for f in frames]

    # Unsupported type: report it and return None instead of failing on an
    # unbound local variable.
    print("Frames argument must be None, int or list...")
    return None
195
+
196
+
197
def locate_stack_and_labels(position, prefix="Aligned", population="target"):
    """
    Locate and load the stack and the corresponding segmentation labels.

    Parameters
    ----------
    position : str
        The position directory where the stack and labels are located.
        Backslashes are replaced with forward slashes before use.
    prefix : str, optional
        The prefix used to identify the stack. The default is 'Aligned'.
    population : str, optional
        The population for which the segmentation must be located. The default is 'target'.

    Returns
    -------
    stack : ndarray
        The loaded stack, as returned by `locate_stack`.
    labels : ndarray
        The loaded segmentation labels, as returned by `locate_labels`.

    Raises
    ------
    FileNotFoundError
        Propagated from `locate_stack` when no movie with the given prefix is found.
    AssertionError
        If, after creating the missing label files, the number of label
        images still differs from the number of frames in the stack.

    Notes
    -----
    When fewer label images than frames are found, `fix_missing_labels` is
    called to write empty labels for the missing frames and the labels are
    loaded again before the length check.

    Examples
    --------
    >>> stack, labels = locate_stack_and_labels(position, prefix='Aligned', population="target")
    # Locate and load the stack and segmentation labels for further processing.

    """

    position = position.replace("\\", "/")
    labels = locate_labels(position, population=population)
    stack = locate_stack(position, prefix=prefix)

    if len(labels) < len(stack):
        fix_missing_labels(position, population=population, prefix=prefix)
        labels = locate_labels(position, population=population)

    # Raised explicitly (rather than with an `assert` statement) so that the
    # check is not stripped when Python runs with -O; the exception type is
    # kept for callers that catch AssertionError.
    if len(stack) != len(labels):
        raise AssertionError(
            f"The shape of the stack {stack.shape} does not match with the shape of the labels {labels.shape}"
        )

    return stack, labels
247
+
248
+
249
def _parse_description_int(description_lines, key):
    """Return the integer of the first line starting with `key` (``key=value``), or None."""
    for line in description_lines:
        if line.startswith(key):
            try:
                return int(line.split("=")[-1])
            except ValueError:
                return None
    return None


def auto_load_number_of_frames(stack_path):
    """
    Automatically determine the number of frames in a TIFF image stack.

    The ``ImageDescription`` tag of the first page is searched for
    ``frames=`` and ``slices=`` entries. When it gives no usable value, the
    whole stack is read and the length is inferred from its dimensions.

    Parameters
    ----------
    stack_path : str or None
        Path to the TIFF image stack file.

    Returns
    -------
    int or None
        The number of frames in the image stack. `None` is returned only when
        `stack_path` is `None`.

    Notes
    -----
    - ``frames`` is used when it is larger than 1; otherwise ``slices`` is
      used when present.
    - In the fallback, the length of the first axis of the stack is used,
      except that a 3D stack whose first axis equals the ``channels`` entry of
      the metadata (1 when absent), or a 2D stack, counts as a single frame.
    - Errors raised while opening the file (e.g. a missing file) are not
      caught.

    Examples
    --------
    Automatically detect the number of frames in a TIFF stack:

    >>> frames = auto_load_number_of_frames("experiment_stack.tif")
    Automatically detected stack length: 120...

    A missing path is passed through:

    >>> print(auto_load_number_of_frames(None))
    None

    """

    if stack_path is None:
        return None

    stack_path = stack_path.replace("\\", "/")

    description_lines = []
    with TiffFile(stack_path) as tif:
        try:
            description = None
            for tag in tif.pages[0].tags.values():
                if tag.name == "ImageDescription":
                    description = tag.value
            if isinstance(description, str):
                description_lines = description.split("\n")
        except Exception as e:
            # Metadata is optional: any failure here only means that the
            # length has to be inferred from the pixel data below.
            logger.debug(f"Could not read the TIFF metadata of {stack_path}: {e}")

    n_channels = _parse_description_int(description_lines, "channels")
    if n_channels is None:
        n_channels = 1

    n_frames = _parse_description_int(description_lines, "frames")
    if n_frames is not None and n_frames > 1:
        len_movie = n_frames
    else:
        len_movie = _parse_description_int(description_lines, "slices")

    if len_movie is None:
        # No usable metadata: read the stack and look at its shape
        stack = imread(stack_path)
        len_movie = len(stack)
        if len_movie == n_channels and stack.ndim == 3:
            len_movie = 1
        if stack.ndim == 2:
            len_movie = 1
        del stack
        gc.collect()

    logger.info(f"Automatically detected stack length: {len_movie}...")

    return len_movie
357
+
358
+
359
def _load_frames_to_segment(file, indices, scale_model=None, normalize_kwargs=None):
    """
    Load normalized frames for segmentation.

    Parameters
    ----------
    file : str
        Path to the stack, forwarded to `load_frames`.
    indices : int or array-like of int
        Image indices to read, forwarded to `load_frames`. Entries equal to
        -1 mark channels whose content is set to zero in the output.
    scale_model : float, optional
        Rescaling factor forwarded to `load_frames` as `scale`.
    normalize_kwargs : dict, optional
        Normalization options forwarded to `load_frames`. When `None`, the
        default of `load_frames` is used.

    Returns
    -------
    ndarray or None
        The normalized frames after `interpolate_nan_multichannel`, with the
        channels flagged by -1 zeroed. `None` if `load_frames` returned `None`.

    """

    load_kwargs = {"scale": scale_model, "normalize_input": True}
    if normalize_kwargs is not None:
        # Forward only explicit settings: load_frames unpacks this dict, so
        # passing None through would fail.
        load_kwargs["normalize_kwargs"] = normalize_kwargs

    frames = load_frames(indices, file, **load_kwargs)
    if frames is None:
        # load_frames already reported the loading error
        return None

    frames = interpolate_nan_multichannel(frames)

    # atleast_1d makes the comparison element-wise for lists and scalars too
    missing = np.atleast_1d(indices) == -1
    if np.any(missing):
        frames[:, :, np.where(missing)[0]] = 0.0

    return frames
374
+
375
+
376
def _load_frames_to_measure(file, indices):
    """Load the frames `indices` of stack `file` with `load_frames`, without normalization or rescaling."""
    raw_frames = load_frames(
        img_nums=indices,
        stack_path=file,
        normalize_input=False,
        scale=None,
    )
    return raw_frames
378
+
379
+
380
def load_frames(
    img_nums,
    stack_path,
    scale=None,
    normalize_input=True,
    dtype=np.float64,
    normalize_kwargs=None,
):
    """
    Loads and optionally normalizes and rescales specified frames from a stack located at a given path.

    The frames are read with imageio, infinite values are replaced with NaN,
    the channel axis is moved last, then normalization and rescaling are
    applied if requested. Finally `_fix_no_contrast` is applied to the result.

    Parameters
    ----------
    img_nums : int or list of int
        The index (or indices) of the image frame(s) to load from the stack.
    stack_path : str
        The file path to the stack from which frames are to be loaded.
    scale : float, optional
        The scaling factor to apply to the frames. If None, no scaling is applied (default is None).
    normalize_input : bool, optional
        Whether to normalize the loaded frames. If True, normalization is applied according to
        `normalize_kwargs` (default is True).
    dtype : data-type, optional
        Currently unused: the frames are returned without a final cast. Kept for
        backward compatibility.
    normalize_kwargs : dict, optional
        Keyword arguments to pass to the normalization function. `None` (default) stands for
        {"percentiles": (0., 99.99)}.

    Returns
    -------
    ndarray or None
        The loaded, and possibly normalized and rescaled, frames as a NumPy array. Returns None if there
        is an error in loading the frames.

    Notes
    -----
    - Loading errors are not raised: a message is printed and None is returned.
    - Normalization and scaling are optional and can be customized through function parameters.

    Examples
    --------
    >>> frames = load_frames([0, 1, 2], '/path/to/stack.tif', scale=0.5, normalize_input=True)
    # Loads the first three frames from '/path/to/stack.tif', normalizes them and rescales them by a
    # factor of 0.5.

    """

    # A None default avoids sharing one mutable dict between calls and lets
    # callers pass None to request the default normalization.
    if normalize_kwargs is None:
        normalize_kwargs = {"percentiles": (0.0, 99.99)}

    try:
        with warnings.catch_warnings():
            warnings.filterwarnings(
                "ignore", message=".*MMStack series is missing files.*"
            )
            frames = imageio.imread(stack_path, key=img_nums)
    except Exception as e:
        print(
            f"Error in loading the frame {img_nums} {e}. Please check that the experiment channel information is consistent with the movie being read."
        )
        return None

    try:
        inf_mask = np.isinf(frames)
        if np.any(inf_mask):
            # Cast to float first: NaN cannot be stored in integer arrays
            frames = frames.astype(float)
            frames[inf_mask] = np.nan
    except Exception as e:
        # Best effort: report the problem and carry on with the frames as read
        print(e)

    frames = _rearrange_multichannel_frame(frames)

    if normalize_input:
        frames = normalize_multichannel(frames.astype(float), **normalize_kwargs)

    if scale is not None:
        frames = zoom_multiframes(frames.astype(float), scale)

    # add a fake pixel to prevent auto normalization errors on images that are uniform
    frames = _fix_no_contrast(frames)

    return frames
471
+
472
+
473
def _rearrange_multichannel_frame(
    frame: np.ndarray, n_channels: Optional[int] = None
) -> np.ndarray:
    """
    Return `frame` with its channel axis in last position.

    Parameters
    ----------
    frame : ndarray
        The frame to rearrange.
        - 3D: the channel axis is moved to the end.
        - 2D: a trailing axis of size 1 is added.
        - any other dimensionality: returned untouched.
    n_channels : int, optional
        Expected number of channels. When given and equal to the size of one
        of the axes of a 3D frame, the first such axis is used as the channel
        axis. Otherwise the axis with the smallest size is used.

    Returns
    -------
    ndarray
        The rearranged frame. A 2D input of shape `(H, W)` comes back with
        shape `(H, W, 1)`.

    Examples
    --------
    Channel-last input is left in the same layout:
    >>> _rearrange_multichannel_frame(np.zeros((10, 10, 3))).shape
    (10, 10, 3)

    Channel-first input:
    >>> _rearrange_multichannel_frame(np.zeros((3, 10, 10))).shape
    (10, 10, 3)

    Grayscale image:
    >>> _rearrange_multichannel_frame(np.zeros((10, 10))).shape
    (10, 10, 1)
    """

    if frame.ndim == 2:
        return frame[:, :, np.newaxis]

    if frame.ndim != 3:
        return frame

    axis_sizes = list(frame.shape)
    use_hint = n_channels is not None and n_channels in axis_sizes
    channel_axis = axis_sizes.index(n_channels) if use_hint else np.argmin(frame.shape)

    return np.moveaxis(frame, channel_axis, -1)
533
+
534
+
535
def zoom_multiframes(frames: np.ndarray, zoom_factor: float) -> np.ndarray:
    """
    Resize every channel of a channel-last image by the same zoom factor.

    Parameters
    ----------
    frames : ndarray
        Image indexed as `(height, width, channels)`; each slice along the
        last axis is resized independently.
    zoom_factor : float
        Factor applied to both the height and the width. Values greater than 1
        enlarge the image, values between 0 and 1 shrink it.

    Returns
    -------
    ndarray
        The resized channels, stacked again along the last axis.

    Notes
    -----
    - Resizing is done with `scipy.ndimage.zoom`, called with `order=3` and
      `prefilter=False`.
    - Each channel is copied before being resized.
    """

    from scipy.ndimage import zoom

    factors = [zoom_factor, zoom_factor]
    resized_channels = []
    for channel in range(frames.shape[-1]):
        plane = frames[:, :, channel].copy()
        resized_channels.append(zoom(plane, factors, order=3, prefilter=False))

    # The channels were collected along axis 0: send that axis back to the end
    return np.moveaxis(resized_channels, 0, -1)
576
+
577
+
578
def fix_missing_labels(position, population="target", prefix="Aligned"):
    """
    Create empty label images for the frames that have no label file.

    The stack of the position is loaded to get the number of frames and the
    frame size. Each frame index without a matching ``####.tif`` file in the
    label folder gets a new label image filled with zeros.

    Parameters
    ----------
    position : str
        Path to the position folder containing the movie and the label folders.
    population : str, optional
        `'target'`/`'targets'` map to the ``labels_targets`` folder and
        `'effector'`/`'effectors'` to ``labels_effectors`` (case-insensitive).
        Any other value maps to ``labels_<population>``. Default is 'target'.
    prefix : str, optional
        The prefix used to locate the image stack (default is 'Aligned').

    Returns
    -------
    None
        The function writes the missing label files (int16, axes "YX") in the
        label folder, which is created if it does not exist.

    Notes
    -----
    Label files whose name does not start with an integer are ignored when
    listing the frames that already have a label.

    """

    if not position.endswith(os.sep):
        position += os.sep

    stack = locate_stack(position, prefix=prefix)
    template = np.zeros(stack[0].shape[:2], dtype=np.int16)
    n_frames = len(stack)

    pop = population.lower()
    if pop in ("target", "targets"):
        folder = "labels_targets"
    elif pop in ("effector", "effectors"):
        folder = "labels_effectors"
    else:
        folder = f"labels_{population}"

    # `position` already ends with a separator: do not add a second one
    label_dir = position + folder

    existing_frames = set()
    for label_file in glob(os.sep.join([label_dir, "*.tif"])):
        stem = os.path.basename(label_file).split(".")[0]
        try:
            existing_frames.add(int(stem))
        except ValueError:
            # Not a frame-numbered label file: nothing to account for
            continue

    missing_frames = [f for f in range(n_frames) if f not in existing_frames]
    if not missing_frames:
        return

    os.makedirs(label_dir, exist_ok=True)
    for frame in missing_frames:
        file_name = str(frame).zfill(4) + ".tif"
        save_tiff_imagej_compatible(
            os.sep.join([label_dir, file_name]), template, axes="YX"
        )
638
+
639
+
640
def _get_img_num_per_channel(channels_indices, len_movie, nbr_channels):
    """
    Compute the frame indices of each requested channel in an interleaved multi-channel movie.

    The movie is treated as a flat sequence of ``len_movie * nbr_channels`` frames in which the
    channels alternate (channel 0, channel 1, ..., channel ``nbr_channels - 1``, channel 0, ...).
    For every requested channel, the positions of its frames in that flat sequence are returned.
    A channel given as None is filled with the placeholder value -1.

    Parameters
    ----------
    channels_indices : int or list of (int or None)
        Index (or indices) of the channels of interest. A single integer, including any NumPy
        integer scalar, is treated as a one-element list. A None entry produces a row of -1.
    len_movie : int
        Number of frames per channel (i.e. number of time points). It is cast to int.
    nbr_channels : int
        Total number of interleaved channels in the movie. It is cast to int.

    Returns
    -------
    ndarray
        Integer array with one row per entry of `channels_indices`. Each row holds the flat frame
        indices of that channel, or -1 repeated `len_movie` times when the entry is None.

    Notes
    -----
    - Frames are assumed to be interleaved by channel in a regular sequence.
    - Channel indices are expected to lie in ``[0, nbr_channels)``; other values yield rows of a
      different length, which cannot be assembled into a rectangular array.

    Examples
    --------
    >>> _get_img_num_per_channel([0], 10, 3)  # first of three channels, 10 time points
    array([[ 0,  3,  6,  9, 12, 15, 18, 21, 24, 27]])

    >>> _get_img_num_per_channel([1, 2], 10, 3)  # second and third channels
    array([[ 1,  4,  7, 10, 13, 16, 19, 22, 25, 28],
           [ 2,  5,  8, 11, 14, 17, 20, 23, 26, 29]])

    >>> _get_img_num_per_channel([0, None], 3, 2)  # None -> placeholder row
    array([[ 0,  2,  4],
           [-1, -1, -1]])

    """

    # Accept any integer scalar: np.integer covers every NumPy width (int32, int64, ...),
    # whereas np.int_ alone only matches the platform default integer.
    if isinstance(channels_indices, (int, np.integer)):
        channels_indices = [channels_indices]

    len_movie = int(len_movie)
    nbr_channels = int(nbr_channels)

    # Flat index of every frame in the interleaved movie; built once and sliced per channel.
    all_frames = np.arange(len_movie * nbr_channels)

    img_num_all_channels = [
        all_frames[c::nbr_channels] if c is not None else [-1] * len_movie
        for c in channels_indices
    ]

    return np.array(img_num_all_channels, dtype=int)
707
+
708
+
709
def _extract_channel_indices(channels, required_channels):
    """
    Locate each required channel within the list of available channels.

    For every entry of `required_channels`, the position of that channel name in `channels` is
    looked up. Entries that cannot be resolved are reported as None instead of raising, so the
    output always has the same length and order as `required_channels`.

    Parameters
    ----------
    channels : sequence of str or None
        Names of the channels available in the dataset. Containers that do not expose an
        ``index`` method (e.g. NumPy arrays) are converted to a list first. If None, no channel
        can be resolved and every output entry is None.
    required_channels : iterable of (str or None)
        Names of the channels required by the model or analysis process. The entries None and
        the string ``"None"`` both denote "no channel" and map to None.

    Returns
    -------
    list of (int or None)
        One entry per required channel: its index in `channels`, or None when the channel is
        absent, is None / ``"None"``, or `channels` cannot be searched.

    Examples
    --------
    >>> _extract_channel_indices(['DAPI', 'GFP', 'RFP'], ['GFP', 'RFP'])
    [1, 2]

    >>> _extract_channel_indices(['DAPI', 'GFP'], ['RFP', 'None', 'DAPI'])
    [None, None, 0]

    >>> _extract_channel_indices(None, ['GFP', 'RFP'])
    [None, None]
    """

    # Containers without `.index` (NumPy arrays in particular) would otherwise make every
    # lookup fail silently; convert them to a list so membership can actually be tested.
    if channels is not None and not hasattr(channels, "index"):
        try:
            channels = list(channels)
        except TypeError:
            # Not iterable: nothing can be resolved, handled by the lookup fallback below.
            channels = None

    channel_indices = []
    for c in required_channels:
        if c is None or c == "None":
            channel_indices.append(None)
            continue
        try:
            channel_indices.append(channels.index(c))
        except (ValueError, AttributeError, TypeError):
            # ValueError: channel absent; AttributeError/TypeError: `channels` is None or
            # not searchable. All are reported as an unresolved channel (best effort).
            channel_indices.append(None)

    return channel_indices
765
+
766
+
767
def load_image_dataset(
    datasets, channels, train_spatial_calibration=None, mask_suffix="labelled"
):
    """
    Load image / mask pairs from dataset folders, selecting channels and optionally rescaling.

    Each dataset folder is scanned for ``*.tif`` images. An image is loaded only if a mask named
    ``<image>_<mask_suffix>.tif`` and a JSON configuration ``<image>.json`` exist next to it. The
    configuration provides the channel names of the image (key ``"channels"``) and its spatial
    calibration (key ``"spatial_calibration"``). The requested channels are extracted in the
    requested order; requested channels absent from the image are passed as black frames.

    Parameters
    ----------
    datasets : list of str
        Paths to the dataset folders containing the images, masks and configuration files.
    channels : str or list of str
        The channel(s) to load from the images. A single string is converted into a list.
    train_spatial_calibration : float, optional
        The spatial calibration used during model training. If provided and different from the
        calibration of an image, the image (cubic interpolation) and its mask (nearest neighbour)
        are rescaled by ``im_calib / train_spatial_calibration``. Default is None, in which case
        no rescaling is applied.
    mask_suffix : str, optional
        The suffix used to identify mask files corresponding to the images. Default is 'labelled'.

    Returns
    -------
    tuple of lists
        ``(X, Y, files)``: the channel-last images, the corresponding masks and the paths of the
        image files that were loaded, in matching order.

    Raises
    ------
    AssertionError
        If `channels` is neither a string nor a list, if a selected image is not 3-dimensional,
        or if the number of loaded images does not match the number of loaded masks.

    Notes
    -----
    - Mask filenames are derived from image filenames by appending ``_<mask_suffix>`` before the
      file extension.
    - Images with more than three dimensions, without a mask, without a configuration file, or
      sharing no channel with the requested ones are skipped.
    - Images are visited in sorted path order within each dataset.

    Examples
    --------
    >>> datasets = ['/path/to/dataset1', '/path/to/dataset2']
    >>> channels = ['DAPI', 'GFP']
    >>> X, Y, files = load_image_dataset(datasets, channels, train_spatial_calibration=0.65)
    # Loads DAPI and GFP channels from specified datasets, rescaling images to match a spatial calibration of 0.65.
    """

    from scipy.ndimage import zoom

    if isinstance(channels, str):
        channels = [channels]

    assert isinstance(channels, list), "Please provide a list of channels. Abort."

    X = []
    Y = []
    files = []

    for ds in datasets:
        print(f"Loading data from dataset {ds}...")
        if not ds.endswith(os.sep):
            ds += os.sep
        # Every tif that is not itself a mask; sorted so the output order is reproducible.
        img_paths = sorted(
            set(glob(ds + "*.tif")) - set(glob(ds + f"*_{mask_suffix}.tif"))
        )
        for im in img_paths:
            print(f"{im=}")
            # Derive sibling paths from the extension only, so a ".tif" occurring elsewhere
            # in the path (e.g. in a folder name) is left untouched.
            stem = os.path.splitext(im)[0]
            mask_path = f"{stem}_{mask_suffix}.tif"
            config_path = stem + ".json"

            if not os.path.exists(mask_path):
                continue
            if not os.path.exists(config_path):
                # Without a config neither the channels nor the calibration are known.
                print("Configuration file could not be found... Skipping image.")
                continue

            # load image and mask
            image = imread(im)
            if image.ndim == 2:
                image = image[np.newaxis]
            if image.ndim > 3:
                print("Invalid image shape, skipping")
                continue
            mask = imread(mask_path)

            # Load config
            with open(config_path, "r") as f:
                config = json.load(f)

            existing_channels = config["channels"]
            intersection = list(set(channels) & set(existing_channels))
            print(f"{existing_channels=} {intersection=}")
            if len(intersection) == 0:
                print("Channels could not be found in the config... Skipping image.")
                continue
            im_calib = config["spatial_calibration"]

            # Missing (or None) channels temporarily point to frame 0 and are blacked out
            # right after the selection; fancy indexing returns a copy, so the source image
            # is not modified.
            missing = np.array([c not in existing_channels for c in channels])
            ch_idx = [
                existing_channels.index(c) if c in existing_channels else 0
                for c in channels
            ]
            image = image[ch_idx]
            image[missing] = 0

            # channel-first -> channel-last
            image = np.moveaxis(image, 0, -1)
            assert (
                image.ndim == 3
            ), "The image has a wrong number of dimensions. Abort."

            # Rescale only when a training calibration was actually provided.
            if (
                train_spatial_calibration is not None
                and im_calib != train_spatial_calibration
            ):
                factor = im_calib / train_spatial_calibration
                image = np.stack(
                    [
                        zoom(
                            image[:, :, c].astype(float).copy(),
                            [factor, factor],
                            order=3,
                            prefilter=False,
                        )
                        for c in range(image.shape[-1])
                    ],
                    axis=-1,
                )
                # order=0 keeps the mask labels intact (no interpolated label values)
                mask = zoom(mask, [factor, factor], order=0)

            X.append(image)
            Y.append(mask)
            files.append(im)

    assert len(X) == len(
        Y
    ), "The number of images does not match with the number of masks... Abort."
    return X, Y, files