celldetective 1.4.2__py3-none-any.whl → 1.5.0b1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (152)
  1. celldetective/__init__.py +25 -0
  2. celldetective/__main__.py +62 -43
  3. celldetective/_version.py +1 -1
  4. celldetective/extra_properties.py +477 -399
  5. celldetective/filters.py +192 -97
  6. celldetective/gui/InitWindow.py +541 -411
  7. celldetective/gui/__init__.py +0 -15
  8. celldetective/gui/about.py +44 -39
  9. celldetective/gui/analyze_block.py +120 -84
  10. celldetective/gui/base/__init__.py +0 -0
  11. celldetective/gui/base/channel_norm_generator.py +335 -0
  12. celldetective/gui/base/components.py +249 -0
  13. celldetective/gui/base/feature_choice.py +92 -0
  14. celldetective/gui/base/figure_canvas.py +52 -0
  15. celldetective/gui/base/list_widget.py +133 -0
  16. celldetective/gui/{styles.py → base/styles.py} +92 -36
  17. celldetective/gui/base/utils.py +33 -0
  18. celldetective/gui/base_annotator.py +900 -767
  19. celldetective/gui/classifier_widget.py +6 -22
  20. celldetective/gui/configure_new_exp.py +777 -671
  21. celldetective/gui/control_panel.py +635 -524
  22. celldetective/gui/dynamic_progress.py +449 -0
  23. celldetective/gui/event_annotator.py +2023 -1662
  24. celldetective/gui/generic_signal_plot.py +1292 -944
  25. celldetective/gui/gui_utils.py +899 -1289
  26. celldetective/gui/interactions_block.py +658 -0
  27. celldetective/gui/interactive_timeseries_viewer.py +447 -0
  28. celldetective/gui/json_readers.py +48 -15
  29. celldetective/gui/layouts/__init__.py +5 -0
  30. celldetective/gui/layouts/background_model_free_layout.py +537 -0
  31. celldetective/gui/layouts/channel_offset_layout.py +134 -0
  32. celldetective/gui/layouts/local_correction_layout.py +91 -0
  33. celldetective/gui/layouts/model_fit_layout.py +372 -0
  34. celldetective/gui/layouts/operation_layout.py +68 -0
  35. celldetective/gui/layouts/protocol_designer_layout.py +96 -0
  36. celldetective/gui/pair_event_annotator.py +3130 -2435
  37. celldetective/gui/plot_measurements.py +586 -267
  38. celldetective/gui/plot_signals_ui.py +724 -506
  39. celldetective/gui/preprocessing_block.py +395 -0
  40. celldetective/gui/process_block.py +1678 -1831
  41. celldetective/gui/seg_model_loader.py +580 -473
  42. celldetective/gui/settings/__init__.py +0 -7
  43. celldetective/gui/settings/_cellpose_model_params.py +181 -0
  44. celldetective/gui/settings/_event_detection_model_params.py +95 -0
  45. celldetective/gui/settings/_segmentation_model_params.py +159 -0
  46. celldetective/gui/settings/_settings_base.py +77 -65
  47. celldetective/gui/settings/_settings_event_model_training.py +752 -526
  48. celldetective/gui/settings/_settings_measurements.py +1133 -964
  49. celldetective/gui/settings/_settings_neighborhood.py +574 -488
  50. celldetective/gui/settings/_settings_segmentation_model_training.py +779 -564
  51. celldetective/gui/settings/_settings_signal_annotator.py +329 -305
  52. celldetective/gui/settings/_settings_tracking.py +1304 -1094
  53. celldetective/gui/settings/_stardist_model_params.py +98 -0
  54. celldetective/gui/survival_ui.py +422 -312
  55. celldetective/gui/tableUI.py +1665 -1701
  56. celldetective/gui/table_ops/_maths.py +295 -0
  57. celldetective/gui/table_ops/_merge_groups.py +140 -0
  58. celldetective/gui/table_ops/_merge_one_hot.py +95 -0
  59. celldetective/gui/table_ops/_query_table.py +43 -0
  60. celldetective/gui/table_ops/_rename_col.py +44 -0
  61. celldetective/gui/thresholds_gui.py +382 -179
  62. celldetective/gui/viewers/__init__.py +0 -0
  63. celldetective/gui/viewers/base_viewer.py +700 -0
  64. celldetective/gui/viewers/channel_offset_viewer.py +331 -0
  65. celldetective/gui/viewers/contour_viewer.py +394 -0
  66. celldetective/gui/viewers/size_viewer.py +153 -0
  67. celldetective/gui/viewers/spot_detection_viewer.py +341 -0
  68. celldetective/gui/viewers/threshold_viewer.py +309 -0
  69. celldetective/gui/workers.py +403 -126
  70. celldetective/log_manager.py +92 -0
  71. celldetective/measure.py +1895 -1478
  72. celldetective/napari/__init__.py +0 -0
  73. celldetective/napari/utils.py +1025 -0
  74. celldetective/neighborhood.py +1914 -1448
  75. celldetective/preprocessing.py +1620 -1220
  76. celldetective/processes/__init__.py +0 -0
  77. celldetective/processes/background_correction.py +271 -0
  78. celldetective/processes/compute_neighborhood.py +894 -0
  79. celldetective/processes/detect_events.py +246 -0
  80. celldetective/processes/downloader.py +137 -0
  81. celldetective/processes/measure_cells.py +565 -0
  82. celldetective/processes/segment_cells.py +760 -0
  83. celldetective/processes/track_cells.py +435 -0
  84. celldetective/processes/train_segmentation_model.py +694 -0
  85. celldetective/processes/train_signal_model.py +265 -0
  86. celldetective/processes/unified_process.py +292 -0
  87. celldetective/regionprops/_regionprops.py +358 -317
  88. celldetective/relative_measurements.py +987 -710
  89. celldetective/scripts/measure_cells.py +313 -212
  90. celldetective/scripts/measure_relative.py +90 -46
  91. celldetective/scripts/segment_cells.py +165 -104
  92. celldetective/scripts/segment_cells_thresholds.py +96 -68
  93. celldetective/scripts/track_cells.py +198 -149
  94. celldetective/scripts/train_segmentation_model.py +324 -201
  95. celldetective/scripts/train_signal_model.py +87 -45
  96. celldetective/segmentation.py +844 -749
  97. celldetective/signals.py +3514 -2861
  98. celldetective/tracking.py +30 -15
  99. celldetective/utils/__init__.py +0 -0
  100. celldetective/utils/cellpose_utils/__init__.py +133 -0
  101. celldetective/utils/color_mappings.py +42 -0
  102. celldetective/utils/data_cleaning.py +630 -0
  103. celldetective/utils/data_loaders.py +450 -0
  104. celldetective/utils/dataset_helpers.py +207 -0
  105. celldetective/utils/downloaders.py +235 -0
  106. celldetective/utils/event_detection/__init__.py +8 -0
  107. celldetective/utils/experiment.py +1782 -0
  108. celldetective/utils/image_augmenters.py +308 -0
  109. celldetective/utils/image_cleaning.py +74 -0
  110. celldetective/utils/image_loaders.py +926 -0
  111. celldetective/utils/image_transforms.py +335 -0
  112. celldetective/utils/io.py +62 -0
  113. celldetective/utils/mask_cleaning.py +348 -0
  114. celldetective/utils/mask_transforms.py +5 -0
  115. celldetective/utils/masks.py +184 -0
  116. celldetective/utils/maths.py +351 -0
  117. celldetective/utils/model_getters.py +325 -0
  118. celldetective/utils/model_loaders.py +296 -0
  119. celldetective/utils/normalization.py +380 -0
  120. celldetective/utils/parsing.py +465 -0
  121. celldetective/utils/plots/__init__.py +0 -0
  122. celldetective/utils/plots/regression.py +53 -0
  123. celldetective/utils/resources.py +34 -0
  124. celldetective/utils/stardist_utils/__init__.py +104 -0
  125. celldetective/utils/stats.py +90 -0
  126. celldetective/utils/types.py +21 -0
  127. {celldetective-1.4.2.dist-info → celldetective-1.5.0b1.dist-info}/METADATA +1 -1
  128. celldetective-1.5.0b1.dist-info/RECORD +187 -0
  129. {celldetective-1.4.2.dist-info → celldetective-1.5.0b1.dist-info}/WHEEL +1 -1
  130. tests/gui/test_new_project.py +129 -117
  131. tests/gui/test_project.py +127 -79
  132. tests/test_filters.py +39 -15
  133. tests/test_notebooks.py +8 -0
  134. tests/test_tracking.py +232 -13
  135. tests/test_utils.py +123 -77
  136. celldetective/gui/base_components.py +0 -23
  137. celldetective/gui/layouts.py +0 -1602
  138. celldetective/gui/processes/compute_neighborhood.py +0 -594
  139. celldetective/gui/processes/downloader.py +0 -111
  140. celldetective/gui/processes/measure_cells.py +0 -360
  141. celldetective/gui/processes/segment_cells.py +0 -499
  142. celldetective/gui/processes/track_cells.py +0 -303
  143. celldetective/gui/processes/train_segmentation_model.py +0 -270
  144. celldetective/gui/processes/train_signal_model.py +0 -108
  145. celldetective/gui/table_ops/merge_groups.py +0 -118
  146. celldetective/gui/viewers.py +0 -1354
  147. celldetective/io.py +0 -3663
  148. celldetective/utils.py +0 -3108
  149. celldetective-1.4.2.dist-info/RECORD +0 -123
  150. {celldetective-1.4.2.dist-info → celldetective-1.5.0b1.dist-info}/entry_points.txt +0 -0
  151. {celldetective-1.4.2.dist-info → celldetective-1.5.0b1.dist-info}/licenses/LICENSE +0 -0
  152. {celldetective-1.4.2.dist-info → celldetective-1.5.0b1.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,694 @@
1
+ from distutils.dir_util import copy_tree
2
+ from multiprocessing import Process
3
+ import time
4
+ import os
5
+ import shutil
6
+ from glob import glob
7
+ import json
8
+ import logging
9
+ import re
10
+
11
+ from tensorflow.python.keras.callbacks import Callback
12
+ from tqdm import tqdm
13
+ import numpy as np
14
+ import random
15
+
16
+ from celldetective.utils.image_augmenters import augmenter
17
+ from celldetective.utils.image_loaders import load_image_dataset
18
+ from celldetective.utils.image_cleaning import interpolate_nan
19
+ from celldetective.utils.normalization import normalize_multichannel
20
+ from celldetective.utils.mask_cleaning import fill_label_holes
21
+ from art import tprint
22
+ from csbdeep.utils import save_json
23
+ from celldetective import get_logger
24
+
25
+ logger = get_logger()
26
+
27
+
28
class ProgressCallback(Callback):
    """Keras callback that streams training progress to a multiprocessing queue.

    After each epoch, two messages are pushed to ``queue`` (when one is set):
    a ``[percent_done, estimated_seconds_remaining]`` pair for the progress
    bar, and a ``{"plot_data": {...}}`` payload carrying the epoch's metrics
    for live plotting. Setting ``stop_event`` stops training cooperatively.

    Parameters
    ----------
    queue : multiprocessing.Queue, optional
        Channel to the parent process; ``None`` disables all reporting.
    epochs : int
        Total number of epochs, used for percentage and ETA computation.
    stop_event : multiprocessing.Event, optional
        When set, training is stopped at the next epoch boundary.
    """

    def __init__(self, queue=None, epochs=100, stop_event=None):
        super().__init__()
        self.queue = queue
        self.epochs = epochs
        self.stop_event = stop_event
        self.t0 = time.time()  # wall-clock start, used for the ETA estimate

    def on_epoch_end(self, epoch, logs=None):

        # Cooperative cancellation: ask Keras to stop at the next opportunity.
        if self.stop_event and self.stop_event.is_set():
            self.model.stop_training = True
            return

        if logs is None:
            logs = {}

        # Send signal for progress bar
        sum_done = (epoch + 1) / self.epochs * 100
        mean_exec_per_step = (time.time() - self.t0) / (epoch + 1)
        pred_time = (self.epochs - (epoch + 1)) * mean_exec_per_step
        if self.queue is not None:
            self.queue.put([sum_done, pred_time])

        # Plot update
        metrics = {k: v for k, v in logs.items() if not k.startswith("val_")}
        val_metrics = {k: v for k, v in logs.items() if k.startswith("val_")}

        plot_data = {
            "epoch": epoch,
            "metrics": metrics,
            "val_metrics": val_metrics,
            "model_name": "StarDist",
            "total_epochs": self.epochs,
        }
        # BUGFIX: this put was previously unconditional, which raised
        # AttributeError when the callback was constructed with queue=None
        # (the documented default). Guard it like the progress put above.
        if self.queue is not None:
            self.queue.put({"plot_data": plot_data})
65
+
66
+
67
class QueueLoggingHandler(logging.Handler):
    """Logging handler that converts Cellpose training log lines into queue messages.

    Cellpose reports per-epoch progress through its logger. Each record whose
    formatted text matches ``Epoch <n>, Time ..., Loss <x>[, Loss Test <y>]``
    produces two messages on ``queue``: a ``[percent_done, eta_seconds]``
    pair and a ``{"plot_data": {...}}`` payload for live loss plotting.
    While ``stop_event`` is set, emitting any record raises
    ``InterruptedError`` so the training loop can be broken out of.
    """

    def __init__(self, queue, total_epochs, stop_event=None):
        super().__init__()
        self.queue = queue
        self.total_epochs = total_epochs
        self.stop_event = stop_event
        self.epoch_pattern = re.compile(
            r"Epoch (\d+), Time .*, Loss ([\d\.eE\-\+naninf]+)(?:, Loss Test ([\d\.eE\-\+naninf]+))?",
            re.IGNORECASE,
        )
        self.t0 = time.time()  # start time used for the ETA estimate

    def emit(self, record):
        # Can't easily stop cellpose_utils loop from here without raising exception or hacking
        # raising exception might be safest to exit training loop
        if self.stop_event and self.stop_event.is_set():
            raise InterruptedError("Training interrupted")

        match = self.epoch_pattern.search(self.format(record))
        if not match:
            return

        epoch_str, loss_str, val_loss_str = match.groups()
        epoch = int(epoch_str)
        completed = epoch + 1

        # Progress message: percent done plus a crude ETA from the mean
        # per-epoch wall-clock time so far.
        fraction_done = completed / self.total_epochs * 100
        seconds_per_epoch = (time.time() - self.t0) / completed
        eta = (self.total_epochs - completed) * seconds_per_epoch
        self.queue.put([fraction_done, eta])

        # Plot message: training loss, plus validation loss when present.
        train_metrics = {"loss": float(loss_str)}
        validation_metrics = {"val_loss": float(val_loss_str)} if val_loss_str else {}
        self.queue.put(
            {
                "plot_data": {
                    "epoch": epoch,
                    "metrics": train_metrics,
                    "val_metrics": validation_metrics,
                    "model_name": "Cellpose",
                    "total_epochs": self.total_epochs,
                }
            }
        )
111
+
112
+
113
class TrainSegModelProcess(Process):
    """Background process that trains a segmentation model (StarDist or Cellpose).

    ``process_args`` is a dict whose entries are copied verbatim onto the
    instance; it must at least provide ``instructions`` (path to a JSON file
    of training parameters) and, per the ``hasattr`` checks below, may carry
    a ``stop_event``. The constructor eagerly reads the instructions, loads
    and normalizes the dataset and performs the train/validation split;
    ``run`` then dispatches to the model-specific training routine and
    signals ``"finished"`` on the queue.
    """

    def __init__(self, queue=None, process_args=None, *args, **kwargs):
        """Prepare the process: copy args, read instructions, load and split data.

        Parameters
        ----------
        queue : multiprocessing.Queue, optional
            Channel used to report progress and plot data to the parent.
        process_args : dict, optional
            Arbitrary attributes (``instructions`` path, ``stop_event``, ...)
            set verbatim on the instance.
        """

        super().__init__(*args, **kwargs)

        self.queue = queue

        # Promote every entry of process_args to an instance attribute.
        if process_args is not None:
            for key, value in process_args.items():
                setattr(self, key, value)

        tprint("Train segmentation")
        # NOTE: dataset loading happens here in the parent process (in
        # __init__), not in run(); the arrays are carried over to the child.
        self.read_instructions()
        self.extract_training_params()
        self.load_dataset()
        self.split_test_train()

        self.sum_done = 0  # progress percentage accumulator
        self.t0 = time.time()  # wall-clock start for ETA estimates

    def read_instructions(self):
        """Load the training-instructions JSON into ``self.training_instructions``.

        Aborts the process if ``self.instructions`` does not point to a file.
        """

        if os.path.exists(self.instructions):
            with open(self.instructions, "r") as f:
                self.training_instructions = json.load(f)
        else:
            logger.error("Training instructions could not be found. Abort.")
            self.abort_process()

    def run(self):
        """Process entry point: train the requested model type, then signal completion."""

        self.queue.put("Loading dataset...")

        if self.model_type == "cellpose":
            self.train_cellpose_model()
        elif self.model_type == "stardist":
            self.train_stardist_model()

        self.queue.put("finished")
        self.queue.close()

    def train_stardist_model(self):
        """Train (or fine-tune) a StarDist 2D model and write its input config.

        Fresh training builds a new model from a ``Config2D``; transfer
        learning copies the pretrained model folder into place, overrides its
        training hyperparameters and freezes the first half of the network.
        Keras console output is intercepted to feed progress/plot messages to
        ``self.queue``. Finally thresholds are optimized on the validation
        set and a ``config_input.json`` describing preprocessing is written.
        """

        from stardist import calculate_extents, gputools_available
        from stardist.models import Config2D, StarDist2D

        n_rays = 32  # number of radial directions per object
        logger.info(gputools_available())

        n_channel = self.X_trn[0].shape[-1]

        # Predict on subsampled grid for increased efficiency and larger field of view
        grid = (2, 2)
        conf = Config2D(
            n_rays=n_rays,
            grid=grid,
            use_gpu=self.use_gpu,
            n_channel_in=n_channel,
            train_learning_rate=self.learning_rate,
            train_patch_size=(256, 256),
            train_epochs=self.epochs,
            train_reduce_lr={"factor": 0.1, "patience": 30, "min_delta": 0},
            train_batch_size=self.batch_size,
            train_steps_per_epoch=int(self.augmentation_factor * len(self.X_trn)),
        )

        if self.use_gpu:
            from csbdeep.utils.tf import limit_gpu_memory

            limit_gpu_memory(None, allow_growth=True)

        if self.pretrained is None:
            # Fresh model in target_directory/model_name.
            model = StarDist2D(
                conf, name=self.model_name, basedir=self.target_directory
            )
        else:
            # Transfer learning: stash this run's instructions as temp.json so
            # copy_tree cannot clobber them, then copy the pretrained folder.
            os.rename(
                self.instructions,
                os.sep.join([self.target_directory, self.model_name, "temp.json"]),
            )
            # NOTE(review): copy_tree comes from distutils, which was removed
            # in Python 3.12 — shutil.copytree(..., dirs_exist_ok=True) is
            # the modern equivalent; confirm before upgrading.
            copy_tree(
                self.pretrained, os.sep.join([self.target_directory, self.model_name])
            )

            # Drop metadata copied from the pretrained model that does not
            # apply to this training run.
            if os.path.exists(
                os.sep.join(
                    [
                        self.target_directory,
                        self.model_name,
                        "training_instructions.json",
                    ]
                )
            ):
                os.remove(
                    os.sep.join(
                        [
                            self.target_directory,
                            self.model_name,
                            "training_instructions.json",
                        ]
                    )
                )
            if os.path.exists(
                os.sep.join(
                    [self.target_directory, self.model_name, "config_input.json"]
                )
            ):
                os.remove(
                    os.sep.join(
                        [self.target_directory, self.model_name, "config_input.json"]
                    )
                )
            if os.path.exists(
                os.sep.join([self.target_directory, self.model_name, "logs" + os.sep])
            ):
                shutil.rmtree(
                    os.sep.join([self.target_directory, self.model_name, "logs"])
                )
            # Restore this run's instructions under their canonical name.
            os.rename(
                os.sep.join([self.target_directory, self.model_name, "temp.json"]),
                os.sep.join(
                    [
                        self.target_directory,
                        self.model_name,
                        "training_instructions.json",
                    ]
                ),
            )

            # shutil.copytree(pretrained, os.sep.join([target_directory, model_name]))
            # Reload the copied model (config read from disk), then override
            # the training hyperparameters for this fine-tuning run.
            model = StarDist2D(
                None, name=self.model_name, basedir=self.target_directory
            )
            model.config.train_epochs = self.epochs
            model.config.train_batch_size = min(len(self.X_trn), self.batch_size)
            model.config.train_learning_rate = (
                self.learning_rate
            )  # perf seems bad if lr is changed in transfer
            model.config.use_gpu = self.use_gpu
            model.config.train_reduce_lr = {
                "factor": 0.1,
                "patience": 10,
                "min_delta": 0,
            }
            logger.info(f"{model.config=}")

            save_json(
                vars(model.config),
                os.sep.join([self.target_directory, self.model_name, "config.json"]),
            )

        if self.pretrained is not None:
            # Freeze the first half of the layers (approximate encoder) and
            # keep the second half trainable for transfer learning.
            logger.info("Freezing encoder layers for StarDist model...")
            mod = model.keras_model
            encoder_depth = len(mod.layers) // 2

            for layer in mod.layers[:encoder_depth]:
                layer.trainable = False

            # Keep decoder trainable
            for layer in mod.layers[encoder_depth:]:
                layer.trainable = True

        median_size = calculate_extents(list(self.Y_trn), np.mean)
        fov = np.array(model._axes_tile_overlap("YX"))
        logger.info(f"median object size: {median_size}")
        logger.info(f"network field of view : {fov}")
        if any(median_size > fov):
            logger.warning(
                "WARNING: median object size larger than field of view of the neural network."
            )

        import sys

        class StreamToQueue:
            # Lightweight stdout/stderr proxy: mirrors Keras console output to
            # the original stream while extracting epoch/metric lines for the
            # GUI queue.
            def __init__(self, queue, total_epochs, original_stream, stop_event=None):
                self.queue = queue
                self.total_epochs = total_epochs
                self.original_stream = original_stream
                self.stop_event = stop_event
                self.epoch_pattern = re.compile(r"Epoch (\d+)/(\d+)")
                # Generic pattern to capture "key: value" pairs
                self.metric_pattern = re.compile(
                    r"([\w_]+)\s*:\s*([\d\.eE\-\+naninf]+)"
                )
                self.current_epoch = 0
                self.t0 = time.time()
                self.buffer = ""  # holds the trailing, not-yet-complete line

            def write(self, message):
                if self.stop_event and self.stop_event.is_set():
                    raise InterruptedError("Training interrupted by user")

                self.original_stream.write(message)
                self.original_stream.flush()  # Ensure immediate display
                self.buffer += message
                if "\n" in message or "\r" in message:
                    self._parse_buffer()

            def flush(self):
                self.original_stream.flush()

            def _parse_buffer(self):
                # Split accumulated output into complete lines and parse them.
                lines = re.split(r"[\r\n]+", self.buffer)
                # Keep the last incomplete part in buffer
                if not (self.buffer.endswith("\n") or self.buffer.endswith("\r")):
                    self.buffer = lines[-1]
                    lines = lines[:-1]
                else:
                    self.buffer = ""

                for line in lines:
                    if not line.strip():
                        continue

                    # Check for Epoch
                    m_epoch = self.epoch_pattern.search(line)
                    if m_epoch:
                        self.current_epoch = int(m_epoch.group(1))
                        # Put progress?
                        sum_done = (self.current_epoch - 1) / self.total_epochs * 100
                        self.queue.put(
                            [sum_done, 0]
                        )  # Time estimation handled by GUI or ignored
                        continue

                    # Capture all metrics in the line
                    found_metrics = self.metric_pattern.findall(line)
                    if found_metrics:
                        metrics = {}
                        val_metrics = {}

                        for key, val_str in found_metrics:
                            try:
                                val = float(val_str)
                                if key.startswith("val_"):
                                    val_metrics[key] = val
                                else:
                                    metrics[key] = val
                            except ValueError:
                                pass

                        # Only send plot data if we have validation metrics (indicates end of epoch)
                        if metrics and val_metrics:
                            plot_data = {
                                "epoch": self.current_epoch,
                                "metrics": metrics,
                                "val_metrics": val_metrics,
                                "model_name": "StarDist",
                                "total_epochs": self.total_epochs,
                            }
                            self.queue.put({"plot_data": plot_data})

        # Redirect stdout/stderr to capture Keras output
        original_stdout = sys.stdout
        original_stderr = sys.stderr
        stream_parser = StreamToQueue(
            self.queue,
            self.epochs,
            original_stdout,
            stop_event=self.stop_event if hasattr(self, "stop_event") else None,
        )
        sys.stdout = stream_parser
        sys.stderr = stream_parser  # Keras often prints to stderr

        try:
            if self.augmentation_factor == 1.0:
                model.train(
                    self.X_trn,
                    self.Y_trn,
                    validation_data=(self.X_val, self.Y_val),
                    epochs=self.epochs,
                )
            else:
                model.train(
                    self.X_trn,
                    self.Y_trn,
                    validation_data=(self.X_val, self.Y_val),
                    augmenter=augmenter,
                    epochs=self.epochs,
                )
        except Exception as e:
            logger.error(f"Error in StarDist training: {e}")
            raise e
        finally:
            # Always restore the real streams, even on failure/interrupt.
            sys.stdout = original_stdout
            sys.stderr = original_stderr

        model.optimize_thresholds(self.X_val, self.Y_val)

        # calculate_extents may return per-axis values; reduce to a scalar.
        if isinstance(median_size, (list, np.ndarray)):
            median_size_scalar = np.mean(median_size)
        else:
            median_size_scalar = median_size

        # Record the preprocessing/normalization settings needed to apply the
        # model later, alongside the dataset split used for this run.
        config_inputs = {
            "channels": self.target_channels,
            "normalization_percentile": self.normalization_percentile,
            "normalization_clip": self.normalization_clip,
            "normalization_values": self.normalization_values,
            "model_type": "stardist",
            "spatial_calibration": self.spatial_calibration,
            "cell_size_um": float(median_size_scalar * self.spatial_calibration),
            "dataset": {"train": self.files_train, "validation": self.files_val},
        }

        def make_json_safe(obj):
            # JSON serializer fallback for NumPy arrays/scalars; anything
            # else is stringified.
            if isinstance(obj, np.ndarray):
                return obj.tolist()
            if isinstance(obj, (np.int64, np.int32)):
                return int(obj)
            if isinstance(obj, (np.float32, np.float64)):
                return float(obj)
            return str(obj)

        json_input_config = json.dumps(config_inputs, indent=4, default=make_json_safe)
        with open(
            os.sep.join([self.target_directory, self.model_name, "config_input.json"]),
            "w",
        ) as outfile:
            outfile.write(json_input_config)

    def train_cellpose_model(self):
        """Train (or fine-tune) a Cellpose model and write its input config.

        Augmentation is performed up front (not per batch), Cellpose logs are
        intercepted by a ``QueueLoggingHandler`` to report progress, and for
        transfer learning the downsampling/style branches are frozen. The
        trained weights are moved out of Cellpose's ``models/`` subfolder and
        a ``config_input.json`` describing preprocessing is written.
        """

        # do augmentation in place
        X_aug = []
        Y_aug = []
        n_val = max(1, int(round(self.augmentation_factor * len(self.X_trn))))
        # Sample training indices with replacement to reach the target count.
        indices = random.choices(list(np.arange(len(self.X_trn))), k=n_val)
        logger.info("Performing image augmentation pre-training...")
        for i in tqdm(indices):
            x_aug, y_aug = augmenter(self.X_trn[i], self.Y_trn[i])
            X_aug.append(x_aug)
            Y_aug.append(y_aug)

        # Channel axis in front for cellpose_utils
        X_aug = [np.moveaxis(x, -1, 0) for x in X_aug]
        self.X_val = [np.moveaxis(x, -1, 0) for x in self.X_val]
        logger.info("number of augmented images: %3d" % len(X_aug))

        from cellpose.models import CellposeModel
        from cellpose.io import logger_setup
        import torch

        if not self.use_gpu:
            logger.info("Using CPU for training...")
            # NOTE(review): `device` is assigned but never passed to the
            # model; CPU/GPU selection actually relies on gpu=self.use_gpu
            # below — confirm whether this line is needed at all.
            device = torch.device("cpu")
        else:
            logger.info("Using GPU for training...")

        # logger_setup configures console and file handlers for cellpose_utils
        _, log_file = logger_setup()

        # Get cellpose_utils logger explicitly to ensure we catch all cellpose_utils logs (e.g. from models)
        logger_cellpose = logging.getLogger("cellpose")

        # Add custom handler
        handler = QueueLoggingHandler(
            self.queue,
            self.epochs,
            stop_event=self.stop_event if hasattr(self, "stop_event") else None,
        )
        handler.setLevel(logging.INFO)
        logger_cellpose.addHandler(handler)

        try:
            logger.info(f"Pretrained model: {self.pretrained}")
            if self.pretrained is not None:
                # Pretrained weights are stored as <folder>/<folder_name>.
                pretrained_path = os.sep.join(
                    [self.pretrained, os.path.split(self.pretrained)[-1]]
                )
            else:
                pretrained_path = self.pretrained

            model = CellposeModel(
                gpu=self.use_gpu,
                model_type=None,
                pretrained_model=pretrained_path,
                diam_mean=30.0,
                nchan=X_aug[0].shape[0],
            )

            if self.pretrained is not None:
                # Transfer learning: freeze the encoder, keep decoder and
                # output head(s) trainable.
                logger.info("Freezing encoder layers for Cellpose model...")
                for param in model.net.downsample.parameters():
                    param.requires_grad = False

                # Optional: freeze style branch
                for param in model.net.make_style.parameters():
                    param.requires_grad = False

                # Keep decoder trainable
                for param in model.net.upsample.parameters():
                    param.requires_grad = True

                # Keep output head trainable
                for param in model.net.output.parameters():
                    param.requires_grad = True

                # Unfreeze all output heads (version-safe)
                output_heads = ["output", "output_conv", "flow", "prob"]
                for head_name in output_heads:
                    if hasattr(model.net, head_name):
                        for param in getattr(model.net, head_name).parameters():
                            param.requires_grad = True

            model.train(
                train_data=X_aug,
                train_labels=Y_aug,
                normalize=False,
                channels=None,
                batch_size=self.batch_size,
                min_train_masks=1,
                save_path=self.target_directory + os.sep + self.model_name,
                n_epochs=self.epochs,
                model_name=self.model_name,
                learning_rate=self.learning_rate,
                test_data=self.X_val,
                test_labels=self.Y_val,
            )
        except InterruptedError:
            # Raised by QueueLoggingHandler when the user cancels.
            logger.info("Training interrupted.")
        except Exception as e:
            logger.error(f"Error during training: {e}")
            raise e
        finally:
            # Always detach the queue handler from the cellpose logger.
            logger_cellpose.removeHandler(handler)

        # Cellpose saves weights under <save_path>/models/; flatten the file
        # into the model folder and remove the now-empty subdirectory.
        file_to_move = glob(
            os.sep.join([self.target_directory, self.model_name, "models", "*"])
        )[0]
        shutil.move(
            file_to_move,
            os.sep.join([self.target_directory, self.model_name, ""])
            + os.path.split(file_to_move)[-1],
        )
        os.rmdir(os.sep.join([self.target_directory, self.model_name, "models"]))

        diameter = model.diam_labels

        # The CP_nuclei base model uses a 17-px reference diameter; every
        # other model uses the standard 30 px.
        if (
            self.pretrained is not None
            and os.path.split(self.pretrained)[-1] == "CP_nuclei"
        ):
            standard_diameter = 17.0
        else:
            standard_diameter = 30.0

        input_spatial_calibration = (
            self.spatial_calibration
        )  # *diameter / standard_diameter

        # Record the inference defaults and preprocessing settings needed to
        # apply the model later, alongside the dataset split for this run.
        config_inputs = {
            "channels": self.target_channels,
            "diameter": standard_diameter,
            "cellprob_threshold": 0.0,
            "flow_threshold": 0.4,
            "normalization_percentile": self.normalization_percentile,
            "normalization_clip": self.normalization_clip,
            "normalization_values": self.normalization_values,
            "model_type": "cellpose",
            "spatial_calibration": input_spatial_calibration,
            "cell_size_um": round(diameter * input_spatial_calibration, 4),
            "dataset": {"train": self.files_train, "validation": self.files_val},
        }

        def make_json_safe(obj):
            # JSON serializer fallback for NumPy arrays/scalars; anything
            # else is stringified.
            if isinstance(obj, np.ndarray):
                return obj.tolist()
            if isinstance(obj, (np.int64, np.int32)):
                return int(obj)
            if isinstance(obj, (np.float32, np.float64)):
                return float(obj)
            return str(obj)

        json_input_config = json.dumps(config_inputs, indent=4, default=make_json_safe)
        with open(
            os.sep.join([self.target_directory, self.model_name, "config_input.json"]),
            "w",
        ) as outfile:
            outfile.write(json_input_config)

    def split_test_train(self):
        """Randomly split ``self.X``/``self.Y`` into training and validation sets.

        Also records the corresponding filenames in ``files_train`` /
        ``files_val``. Aborts if fewer than two samples are available.
        """

        if not len(self.X) > 1:
            logger.error("Not enough training data")
            self.abort_process()

        # Unseeded permutation: the split differs between runs.
        rng = np.random.RandomState()
        ind = rng.permutation(len(self.X))
        n_val = max(1, int(round(self.validation_split * len(ind))))
        ind_train, ind_val = ind[:-n_val], ind[-n_val:]
        self.X_val, self.Y_val = [self.X[i] for i in ind_val], [
            self.Y[i] for i in ind_val
        ]
        self.X_trn, self.Y_trn = [self.X[i] for i in ind_train], [
            self.Y[i] for i in ind_train
        ]

        self.files_train = [self.filenames[i] for i in ind_train]
        self.files_val = [self.filenames[i] for i in ind_val]

        logger.info("number of images: %3d" % len(self.X))
        logger.info("- training: %3d" % len(self.X_trn))
        logger.info("- validation: %3d" % len(self.X_val))

    def extract_training_params(self):
        """Unpack ``self.training_instructions`` into individual attributes.

        Raises ``KeyError`` if any expected key is missing from the
        instructions JSON.
        """

        self.model_name = self.training_instructions["model_name"]
        self.target_directory = self.training_instructions["target_directory"]
        self.model_type = self.training_instructions["model_type"]
        self.pretrained = self.training_instructions["pretrained"]

        self.datasets = self.training_instructions["ds"]

        self.target_channels = self.training_instructions["channel_option"]
        self.normalization_percentile = self.training_instructions[
            "normalization_percentile"
        ]
        self.normalization_clip = self.training_instructions["normalization_clip"]
        self.normalization_values = self.training_instructions["normalization_values"]
        self.spatial_calibration = self.training_instructions["spatial_calibration"]

        self.validation_split = self.training_instructions["validation_split"]
        self.augmentation_factor = self.training_instructions["augmentation_factor"]

        self.learning_rate = self.training_instructions["learning_rate"]
        self.epochs = self.training_instructions["epochs"]
        self.batch_size = self.training_instructions["batch_size"]

    def load_dataset(self):
        """Load images/masks, normalize per channel, fill NaNs and label holes.

        Populates ``self.X`` (normalized, NaN-interpolated images),
        ``self.Y`` (hole-filled label masks) and ``self.filenames``.
        """

        logger.info(f"Datasets: {self.datasets}")
        self.X, self.Y, self.filenames = load_image_dataset(
            self.datasets,
            self.target_channels,
            train_spatial_calibration=self.spatial_calibration,
            mask_suffix="labelled",
        )
        logger.info("Dataset loaded...")

        # Per channel, either a percentile pair or fixed min/max values is
        # used for normalization — the other entry is set to None.
        self.values = []
        self.percentiles = []
        for k in range(len(self.normalization_percentile)):
            if self.normalization_percentile[k]:
                self.percentiles.append(self.normalization_values[k])
                self.values.append(None)
            else:
                self.percentiles.append(None)
                self.values.append(self.normalization_values[k])

        self.X = [
            normalize_multichannel(
                x,
                **{
                    "percentiles": self.percentiles,
                    "values": self.values,
                    "clip": self.normalization_clip,
                },
            )
            for x in self.X
        ]

        # Interpolate NaN pixels channel by channel (channel axis last).
        for k in range(len(self.X)):
            x = self.X[k].copy()
            x_interp = np.moveaxis(
                [interpolate_nan(x[:, :, c].copy()) for c in range(x.shape[-1])], 0, -1
            )
            self.X[k] = x_interp

        self.Y = [fill_label_holes(y) for y in tqdm(self.Y)]

    def end_process(self):
        """Terminate the process and report normal completion on the queue."""

        self.terminate()
        self.queue.put("finished")

    def abort_process(self):
        """Terminate the process and report an error on the queue.

        NOTE(review): this is also called from __init__ (before start()),
        where Process.terminate() is invalid on an unstarted process —
        confirm the intended abort path.
        """

        self.terminate()
        self.queue.put("error")