celldetective 1.4.2__py3-none-any.whl → 1.5.0b1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- celldetective/__init__.py +25 -0
- celldetective/__main__.py +62 -43
- celldetective/_version.py +1 -1
- celldetective/extra_properties.py +477 -399
- celldetective/filters.py +192 -97
- celldetective/gui/InitWindow.py +541 -411
- celldetective/gui/__init__.py +0 -15
- celldetective/gui/about.py +44 -39
- celldetective/gui/analyze_block.py +120 -84
- celldetective/gui/base/__init__.py +0 -0
- celldetective/gui/base/channel_norm_generator.py +335 -0
- celldetective/gui/base/components.py +249 -0
- celldetective/gui/base/feature_choice.py +92 -0
- celldetective/gui/base/figure_canvas.py +52 -0
- celldetective/gui/base/list_widget.py +133 -0
- celldetective/gui/{styles.py → base/styles.py} +92 -36
- celldetective/gui/base/utils.py +33 -0
- celldetective/gui/base_annotator.py +900 -767
- celldetective/gui/classifier_widget.py +6 -22
- celldetective/gui/configure_new_exp.py +777 -671
- celldetective/gui/control_panel.py +635 -524
- celldetective/gui/dynamic_progress.py +449 -0
- celldetective/gui/event_annotator.py +2023 -1662
- celldetective/gui/generic_signal_plot.py +1292 -944
- celldetective/gui/gui_utils.py +899 -1289
- celldetective/gui/interactions_block.py +658 -0
- celldetective/gui/interactive_timeseries_viewer.py +447 -0
- celldetective/gui/json_readers.py +48 -15
- celldetective/gui/layouts/__init__.py +5 -0
- celldetective/gui/layouts/background_model_free_layout.py +537 -0
- celldetective/gui/layouts/channel_offset_layout.py +134 -0
- celldetective/gui/layouts/local_correction_layout.py +91 -0
- celldetective/gui/layouts/model_fit_layout.py +372 -0
- celldetective/gui/layouts/operation_layout.py +68 -0
- celldetective/gui/layouts/protocol_designer_layout.py +96 -0
- celldetective/gui/pair_event_annotator.py +3130 -2435
- celldetective/gui/plot_measurements.py +586 -267
- celldetective/gui/plot_signals_ui.py +724 -506
- celldetective/gui/preprocessing_block.py +395 -0
- celldetective/gui/process_block.py +1678 -1831
- celldetective/gui/seg_model_loader.py +580 -473
- celldetective/gui/settings/__init__.py +0 -7
- celldetective/gui/settings/_cellpose_model_params.py +181 -0
- celldetective/gui/settings/_event_detection_model_params.py +95 -0
- celldetective/gui/settings/_segmentation_model_params.py +159 -0
- celldetective/gui/settings/_settings_base.py +77 -65
- celldetective/gui/settings/_settings_event_model_training.py +752 -526
- celldetective/gui/settings/_settings_measurements.py +1133 -964
- celldetective/gui/settings/_settings_neighborhood.py +574 -488
- celldetective/gui/settings/_settings_segmentation_model_training.py +779 -564
- celldetective/gui/settings/_settings_signal_annotator.py +329 -305
- celldetective/gui/settings/_settings_tracking.py +1304 -1094
- celldetective/gui/settings/_stardist_model_params.py +98 -0
- celldetective/gui/survival_ui.py +422 -312
- celldetective/gui/tableUI.py +1665 -1701
- celldetective/gui/table_ops/_maths.py +295 -0
- celldetective/gui/table_ops/_merge_groups.py +140 -0
- celldetective/gui/table_ops/_merge_one_hot.py +95 -0
- celldetective/gui/table_ops/_query_table.py +43 -0
- celldetective/gui/table_ops/_rename_col.py +44 -0
- celldetective/gui/thresholds_gui.py +382 -179
- celldetective/gui/viewers/__init__.py +0 -0
- celldetective/gui/viewers/base_viewer.py +700 -0
- celldetective/gui/viewers/channel_offset_viewer.py +331 -0
- celldetective/gui/viewers/contour_viewer.py +394 -0
- celldetective/gui/viewers/size_viewer.py +153 -0
- celldetective/gui/viewers/spot_detection_viewer.py +341 -0
- celldetective/gui/viewers/threshold_viewer.py +309 -0
- celldetective/gui/workers.py +403 -126
- celldetective/log_manager.py +92 -0
- celldetective/measure.py +1895 -1478
- celldetective/napari/__init__.py +0 -0
- celldetective/napari/utils.py +1025 -0
- celldetective/neighborhood.py +1914 -1448
- celldetective/preprocessing.py +1620 -1220
- celldetective/processes/__init__.py +0 -0
- celldetective/processes/background_correction.py +271 -0
- celldetective/processes/compute_neighborhood.py +894 -0
- celldetective/processes/detect_events.py +246 -0
- celldetective/processes/downloader.py +137 -0
- celldetective/processes/measure_cells.py +565 -0
- celldetective/processes/segment_cells.py +760 -0
- celldetective/processes/track_cells.py +435 -0
- celldetective/processes/train_segmentation_model.py +694 -0
- celldetective/processes/train_signal_model.py +265 -0
- celldetective/processes/unified_process.py +292 -0
- celldetective/regionprops/_regionprops.py +358 -317
- celldetective/relative_measurements.py +987 -710
- celldetective/scripts/measure_cells.py +313 -212
- celldetective/scripts/measure_relative.py +90 -46
- celldetective/scripts/segment_cells.py +165 -104
- celldetective/scripts/segment_cells_thresholds.py +96 -68
- celldetective/scripts/track_cells.py +198 -149
- celldetective/scripts/train_segmentation_model.py +324 -201
- celldetective/scripts/train_signal_model.py +87 -45
- celldetective/segmentation.py +844 -749
- celldetective/signals.py +3514 -2861
- celldetective/tracking.py +30 -15
- celldetective/utils/__init__.py +0 -0
- celldetective/utils/cellpose_utils/__init__.py +133 -0
- celldetective/utils/color_mappings.py +42 -0
- celldetective/utils/data_cleaning.py +630 -0
- celldetective/utils/data_loaders.py +450 -0
- celldetective/utils/dataset_helpers.py +207 -0
- celldetective/utils/downloaders.py +235 -0
- celldetective/utils/event_detection/__init__.py +8 -0
- celldetective/utils/experiment.py +1782 -0
- celldetective/utils/image_augmenters.py +308 -0
- celldetective/utils/image_cleaning.py +74 -0
- celldetective/utils/image_loaders.py +926 -0
- celldetective/utils/image_transforms.py +335 -0
- celldetective/utils/io.py +62 -0
- celldetective/utils/mask_cleaning.py +348 -0
- celldetective/utils/mask_transforms.py +5 -0
- celldetective/utils/masks.py +184 -0
- celldetective/utils/maths.py +351 -0
- celldetective/utils/model_getters.py +325 -0
- celldetective/utils/model_loaders.py +296 -0
- celldetective/utils/normalization.py +380 -0
- celldetective/utils/parsing.py +465 -0
- celldetective/utils/plots/__init__.py +0 -0
- celldetective/utils/plots/regression.py +53 -0
- celldetective/utils/resources.py +34 -0
- celldetective/utils/stardist_utils/__init__.py +104 -0
- celldetective/utils/stats.py +90 -0
- celldetective/utils/types.py +21 -0
- {celldetective-1.4.2.dist-info → celldetective-1.5.0b1.dist-info}/METADATA +1 -1
- celldetective-1.5.0b1.dist-info/RECORD +187 -0
- {celldetective-1.4.2.dist-info → celldetective-1.5.0b1.dist-info}/WHEEL +1 -1
- tests/gui/test_new_project.py +129 -117
- tests/gui/test_project.py +127 -79
- tests/test_filters.py +39 -15
- tests/test_notebooks.py +8 -0
- tests/test_tracking.py +232 -13
- tests/test_utils.py +123 -77
- celldetective/gui/base_components.py +0 -23
- celldetective/gui/layouts.py +0 -1602
- celldetective/gui/processes/compute_neighborhood.py +0 -594
- celldetective/gui/processes/downloader.py +0 -111
- celldetective/gui/processes/measure_cells.py +0 -360
- celldetective/gui/processes/segment_cells.py +0 -499
- celldetective/gui/processes/track_cells.py +0 -303
- celldetective/gui/processes/train_segmentation_model.py +0 -270
- celldetective/gui/processes/train_signal_model.py +0 -108
- celldetective/gui/table_ops/merge_groups.py +0 -118
- celldetective/gui/viewers.py +0 -1354
- celldetective/io.py +0 -3663
- celldetective/utils.py +0 -3108
- celldetective-1.4.2.dist-info/RECORD +0 -123
- {celldetective-1.4.2.dist-info → celldetective-1.5.0b1.dist-info}/entry_points.txt +0 -0
- {celldetective-1.4.2.dist-info → celldetective-1.5.0b1.dist-info}/licenses/LICENSE +0 -0
- {celldetective-1.4.2.dist-info → celldetective-1.5.0b1.dist-info}/top_level.txt +0 -0
|
@@ -0,0 +1,694 @@
|
|
|
1
|
+
from distutils.dir_util import copy_tree
|
|
2
|
+
from multiprocessing import Process
|
|
3
|
+
import time
|
|
4
|
+
import os
|
|
5
|
+
import shutil
|
|
6
|
+
from glob import glob
|
|
7
|
+
import json
|
|
8
|
+
import logging
|
|
9
|
+
import re
|
|
10
|
+
|
|
11
|
+
from tensorflow.python.keras.callbacks import Callback
|
|
12
|
+
from tqdm import tqdm
|
|
13
|
+
import numpy as np
|
|
14
|
+
import random
|
|
15
|
+
|
|
16
|
+
from celldetective.utils.image_augmenters import augmenter
|
|
17
|
+
from celldetective.utils.image_loaders import load_image_dataset
|
|
18
|
+
from celldetective.utils.image_cleaning import interpolate_nan
|
|
19
|
+
from celldetective.utils.normalization import normalize_multichannel
|
|
20
|
+
from celldetective.utils.mask_cleaning import fill_label_holes
|
|
21
|
+
from art import tprint
|
|
22
|
+
from csbdeep.utils import save_json
|
|
23
|
+
from celldetective import get_logger
|
|
24
|
+
|
|
25
|
+
logger = get_logger()
|
|
26
|
+
|
|
27
|
+
|
|
28
|
+
class ProgressCallback(Callback):
    """Keras callback that forwards epoch progress and metrics to a queue.

    After each epoch it pushes two payloads onto ``queue`` (when one is
    provided):

    - ``[percent_done, predicted_remaining_seconds]`` for the progress bar;
    - ``{"plot_data": {...}}`` with the epoch's train/validation metrics for
      live plotting.

    Parameters
    ----------
    queue : multiprocessing queue or None
        Destination for progress/plot payloads. ``None`` disables reporting.
    epochs : int
        Total number of training epochs (used for percentage / ETA).
    stop_event : threading/multiprocessing Event or None
        When set, training is stopped cooperatively at the next epoch end.
    """

    def __init__(self, queue=None, epochs=100, stop_event=None):
        super().__init__()
        self.queue = queue
        self.epochs = epochs
        self.stop_event = stop_event
        self.t0 = time.time()  # reference for ETA estimation

    def on_epoch_end(self, epoch, logs=None):
        # Cooperative cancellation: ask Keras to stop at the end of this epoch.
        if self.stop_event and self.stop_event.is_set():
            self.model.stop_training = True
            return

        if logs is None:
            logs = {}

        # Progress bar payload: percentage done and predicted remaining time,
        # from the mean wall-clock time per completed epoch so far.
        sum_done = (epoch + 1) / self.epochs * 100
        mean_exec_per_step = (time.time() - self.t0) / (epoch + 1)
        pred_time = (self.epochs - (epoch + 1)) * mean_exec_per_step

        # Split Keras logs into train vs validation metrics by prefix.
        metrics = {k: v for k, v in logs.items() if not k.startswith("val_")}
        val_metrics = {k: v for k, v in logs.items() if k.startswith("val_")}

        plot_data = {
            "epoch": epoch,
            "metrics": metrics,
            "val_metrics": val_metrics,
            "model_name": "StarDist",
            "total_epochs": self.epochs,
        }

        # BUG FIX: the plot_data put was previously unguarded, crashing with
        # AttributeError when the callback was built with the default
        # queue=None. Both puts now honor the guard.
        if self.queue is not None:
            self.queue.put([sum_done, pred_time])
            self.queue.put({"plot_data": plot_data})
|
|
65
|
+
|
|
66
|
+
|
|
67
|
+
class QueueLoggingHandler(logging.Handler):
    """Logging handler that mines Cellpose training logs for progress data.

    Records matching ``"Epoch N, Time ..., Loss X[, Loss Test Y]"`` are turned
    into two queue payloads: a ``[percent_done, eta_seconds]`` list for the
    progress bar and a ``{"plot_data": {...}}`` dict for live metric plots.
    Non-matching records are ignored.
    """

    def __init__(self, queue, total_epochs, stop_event=None):
        super().__init__()
        self.queue = queue
        self.total_epochs = total_epochs
        self.stop_event = stop_event
        # Matches e.g. "Epoch 3, Time 12.5s, Loss 0.42, Loss Test 0.40";
        # the ", Loss Test ..." group is optional.
        self.epoch_pattern = re.compile(
            r"Epoch (\d+), Time .*, Loss ([\d\.eE\-\+naninf]+)(?:, Loss Test ([\d\.eE\-\+naninf]+))?",
            re.IGNORECASE,
        )
        self.t0 = time.time()  # reference for ETA estimation

    def emit(self, record):
        # Cooperative cancellation: there is no clean way to stop the cellpose
        # training loop from a log handler, so raising is the pragmatic exit.
        if self.stop_event and self.stop_event.is_set():
            raise InterruptedError("Training interrupted")

        match = self.epoch_pattern.search(self.format(record))
        if not match:
            return

        epoch_idx = int(match.group(1))
        train_loss = float(match.group(2))
        test_group = match.group(3)

        # Progress payload: percentage complete plus a remaining-time estimate
        # based on mean wall-clock time per completed epoch.
        done_epochs = epoch_idx + 1
        percent_done = done_epochs / self.total_epochs * 100
        per_epoch = (time.time() - self.t0) / done_epochs
        remaining = (self.total_epochs - done_epochs) * per_epoch
        self.queue.put([percent_done, remaining])

        validation = {"val_loss": float(test_group)} if test_group else {}
        self.queue.put(
            {
                "plot_data": {
                    "epoch": epoch_idx,
                    "metrics": {"loss": train_loss},
                    "val_metrics": validation,
                    "model_name": "Cellpose",
                    "total_epochs": self.total_epochs,
                }
            }
        )
|
|
111
|
+
|
|
112
|
+
|
|
113
|
+
class TrainSegModelProcess(Process):
    """Background process that trains a segmentation model (StarDist or Cellpose).

    Configuration is injected through ``process_args`` (every key/value pair is
    set as an instance attribute — at minimum ``instructions``, the path to a
    training-instruction JSON file) and expanded by ``extract_training_params``.
    Progress percentages, plot payloads, and terminal status strings
    ("finished" / "error") are pushed onto ``queue`` for the GUI side.

    NOTE(review): the dataset is loaded, normalized, and split inside
    ``__init__`` — i.e. in the parent process, before ``start()`` — so the
    arrays are pickled over to the child process. Confirm this is intended.
    """

    def __init__(self, queue=None, process_args=None, *args, **kwargs):
        # queue: multiprocessing queue for progress/status messages.
        # process_args: dict of configuration attributes (must provide at least
        # "instructions"; may provide "stop_event", etc.).

        super().__init__(*args, **kwargs)

        self.queue = queue

        # Promote every configuration entry to an instance attribute.
        if process_args is not None:
            for key, value in process_args.items():
                setattr(self, key, value)

        tprint("Train segmentation")  # console ASCII banner
        self.read_instructions()
        self.extract_training_params()
        self.load_dataset()
        self.split_test_train()

        self.sum_done = 0
        self.t0 = time.time()

    def read_instructions(self):
        """Load the training-instruction JSON into ``self.training_instructions``.

        Aborts the process if ``self.instructions`` does not point to a file.
        """

        if os.path.exists(self.instructions):
            with open(self.instructions, "r") as f:
                self.training_instructions = json.load(f)
        else:
            logger.error("Training instructions could not be found. Abort.")
            self.abort_process()

    def run(self):
        """Process entry point: dispatch to the backend-specific trainer.

        NOTE(review): an unrecognized ``model_type`` silently trains nothing
        and still reports "finished" — confirm whether an error is preferable.
        """

        self.queue.put("Loading dataset...")

        if self.model_type == "cellpose":
            self.train_cellpose_model()
        elif self.model_type == "stardist":
            self.train_stardist_model()

        self.queue.put("finished")
        self.queue.close()

    def train_stardist_model(self):
        """Train (or fine-tune) a StarDist 2D model.

        Fresh training builds a ``Config2D`` from the extracted parameters;
        transfer learning copies the pretrained model directory into the
        target location, patches its config, and freezes the encoder half of
        the network. Keras console output is intercepted via a stdout/stderr
        proxy (``StreamToQueue``) to feed progress/metrics to ``self.queue``.
        Finally, probability thresholds are optimized on the validation split
        and a ``config_input.json`` describing the input contract is written.
        """

        # Heavy backend imports deferred to training time.
        from stardist import calculate_extents, gputools_available
        from stardist.models import Config2D, StarDist2D

        n_rays = 32
        logger.info(gputools_available())

        # Channel count inferred from the first training image (channels last).
        n_channel = self.X_trn[0].shape[-1]

        # Predict on subsampled grid for increased efficiency and larger field of view
        grid = (2, 2)
        conf = Config2D(
            n_rays=n_rays,
            grid=grid,
            use_gpu=self.use_gpu,
            n_channel_in=n_channel,
            train_learning_rate=self.learning_rate,
            train_patch_size=(256, 256),
            train_epochs=self.epochs,
            train_reduce_lr={"factor": 0.1, "patience": 30, "min_delta": 0},
            train_batch_size=self.batch_size,
            train_steps_per_epoch=int(self.augmentation_factor * len(self.X_trn)),
        )

        if self.use_gpu:
            from csbdeep.utils.tf import limit_gpu_memory

            # Let TensorFlow grow GPU memory on demand instead of grabbing it all.
            limit_gpu_memory(None, allow_growth=True)

        if self.pretrained is None:
            # Fresh model: StarDist creates the model directory from conf.
            model = StarDist2D(
                conf, name=self.model_name, basedir=self.target_directory
            )
        else:
            # Transfer learning: the instruction file is parked as temp.json so
            # it survives the copy of the pretrained model directory, then
            # restored under its canonical name afterwards.
            os.rename(
                self.instructions,
                os.sep.join([self.target_directory, self.model_name, "temp.json"]),
            )
            # NOTE(review): copy_tree comes from distutils, removed in Python
            # 3.12 — consider shutil.copytree(..., dirs_exist_ok=True).
            copy_tree(
                self.pretrained, os.sep.join([self.target_directory, self.model_name])
            )

            # Drop stale metadata copied over from the pretrained model.
            if os.path.exists(
                os.sep.join(
                    [
                        self.target_directory,
                        self.model_name,
                        "training_instructions.json",
                    ]
                )
            ):
                os.remove(
                    os.sep.join(
                        [
                            self.target_directory,
                            self.model_name,
                            "training_instructions.json",
                        ]
                    )
                )
            if os.path.exists(
                os.sep.join(
                    [self.target_directory, self.model_name, "config_input.json"]
                )
            ):
                os.remove(
                    os.sep.join(
                        [self.target_directory, self.model_name, "config_input.json"]
                    )
                )
            if os.path.exists(
                os.sep.join([self.target_directory, self.model_name, "logs" + os.sep])
            ):
                shutil.rmtree(
                    os.sep.join([self.target_directory, self.model_name, "logs"])
                )
            # Restore the instruction file under its canonical name.
            os.rename(
                os.sep.join([self.target_directory, self.model_name, "temp.json"]),
                os.sep.join(
                    [
                        self.target_directory,
                        self.model_name,
                        "training_instructions.json",
                    ]
                ),
            )

            # shutil.copytree(pretrained, os.sep.join([target_directory, model_name]))
            # Load the copied model (config read from disk), then override the
            # training hyper-parameters for the fine-tuning run.
            model = StarDist2D(
                None, name=self.model_name, basedir=self.target_directory
            )
            model.config.train_epochs = self.epochs
            model.config.train_batch_size = min(len(self.X_trn), self.batch_size)
            model.config.train_learning_rate = (
                self.learning_rate
            )  # perf seems bad if lr is changed in transfer
            model.config.use_gpu = self.use_gpu
            model.config.train_reduce_lr = {
                "factor": 0.1,
                "patience": 10,
                "min_delta": 0,
            }
            logger.info(f"{model.config=}")

            # Persist the patched config next to the model weights.
            save_json(
                vars(model.config),
                os.sep.join([self.target_directory, self.model_name, "config.json"]),
            )

        if self.pretrained is not None:
            logger.info("Freezing encoder layers for StarDist model...")
            mod = model.keras_model
            # NOTE(review): "first half of the layer list" is a heuristic for
            # the encoder — Keras layer order need not match the U-Net split.
            encoder_depth = len(mod.layers) // 2

            for layer in mod.layers[:encoder_depth]:
                layer.trainable = False

            # Keep decoder trainable
            for layer in mod.layers[encoder_depth:]:
                layer.trainable = True

        # Sanity-check the receptive field against typical object size.
        median_size = calculate_extents(list(self.Y_trn), np.mean)
        fov = np.array(model._axes_tile_overlap("YX"))
        logger.info(f"median object size: {median_size}")
        logger.info(f"network field of view : {fov}")
        if any(median_size > fov):
            logger.warning(
                "WARNING: median object size larger than field of view of the neural network."
            )

        import sys

        class StreamToQueue:
            """stdout/stderr proxy that parses Keras progress lines.

            Echoes everything to the original stream, while extracting
            "Epoch i/N" markers and "key: value" metric pairs to feed the
            progress bar and live plots through ``queue``.
            """

            def __init__(self, queue, total_epochs, original_stream, stop_event=None):
                self.queue = queue
                self.total_epochs = total_epochs
                self.original_stream = original_stream
                self.stop_event = stop_event
                self.epoch_pattern = re.compile(r"Epoch (\d+)/(\d+)")
                # Generic pattern to capture "key: value" pairs
                self.metric_pattern = re.compile(
                    r"([\w_]+)\s*:\s*([\d\.eE\-\+naninf]+)"
                )
                self.current_epoch = 0
                self.t0 = time.time()
                self.buffer = ""  # holds partial lines between write() calls

            def write(self, message):
                # Cooperative cancellation: raising aborts the training loop.
                if self.stop_event and self.stop_event.is_set():
                    raise InterruptedError("Training interrupted by user")

                self.original_stream.write(message)
                self.original_stream.flush()  # Ensure immediate display
                self.buffer += message
                if "\n" in message or "\r" in message:
                    self._parse_buffer()

            def flush(self):
                self.original_stream.flush()

            def _parse_buffer(self):
                # Keras uses \r for in-place progress updates, so split on both.
                lines = re.split(r"[\r\n]+", self.buffer)
                # Keep the last incomplete part in buffer
                if not (self.buffer.endswith("\n") or self.buffer.endswith("\r")):
                    self.buffer = lines[-1]
                    lines = lines[:-1]
                else:
                    self.buffer = ""

                for line in lines:
                    if not line.strip():
                        continue

                    # Check for Epoch
                    m_epoch = self.epoch_pattern.search(line)
                    if m_epoch:
                        self.current_epoch = int(m_epoch.group(1))
                        # Put progress?
                        sum_done = (self.current_epoch - 1) / self.total_epochs * 100
                        self.queue.put(
                            [sum_done, 0]
                        )  # Time estimation handled by GUI or ignored
                        continue

                    # Capture all metrics in the line
                    found_metrics = self.metric_pattern.findall(line)
                    if found_metrics:
                        metrics = {}
                        val_metrics = {}

                        for key, val_str in found_metrics:
                            try:
                                val = float(val_str)
                                if key.startswith("val_"):
                                    val_metrics[key] = val
                                else:
                                    metrics[key] = val
                            except ValueError:
                                pass

                        # Only send plot data if we have validation metrics (indicates end of epoch)
                        if metrics and val_metrics:
                            plot_data = {
                                "epoch": self.current_epoch,
                                "metrics": metrics,
                                "val_metrics": val_metrics,
                                "model_name": "StarDist",
                                "total_epochs": self.total_epochs,
                            }
                            self.queue.put({"plot_data": plot_data})

        # Redirect stdout/stderr to capture Keras output
        original_stdout = sys.stdout
        original_stderr = sys.stderr
        stream_parser = StreamToQueue(
            self.queue,
            self.epochs,
            original_stdout,
            stop_event=self.stop_event if hasattr(self, "stop_event") else None,
        )
        sys.stdout = stream_parser
        sys.stderr = stream_parser  # Keras often prints to stderr

        try:
            # augmentation_factor == 1.0 means "no augmentation requested".
            if self.augmentation_factor == 1.0:
                model.train(
                    self.X_trn,
                    self.Y_trn,
                    validation_data=(self.X_val, self.Y_val),
                    epochs=self.epochs,
                )
            else:
                model.train(
                    self.X_trn,
                    self.Y_trn,
                    validation_data=(self.X_val, self.Y_val),
                    augmenter=augmenter,
                    epochs=self.epochs,
                )
        except Exception as e:
            logger.error(f"Error in StarDist training: {e}")
            raise e
        finally:
            # Always restore the real streams, even on failure/interrupt.
            sys.stdout = original_stdout
            sys.stderr = original_stderr

        # Calibrate NMS/probability thresholds on the validation split.
        model.optimize_thresholds(self.X_val, self.Y_val)

        # calculate_extents may return a per-axis array; collapse to a scalar.
        if isinstance(median_size, (list, np.ndarray)):
            median_size_scalar = np.mean(median_size)
        else:
            median_size_scalar = median_size

        # Input contract consumed at inference time (channels, normalization,
        # calibration, and the dataset split used for this training run).
        config_inputs = {
            "channels": self.target_channels,
            "normalization_percentile": self.normalization_percentile,
            "normalization_clip": self.normalization_clip,
            "normalization_values": self.normalization_values,
            "model_type": "stardist",
            "spatial_calibration": self.spatial_calibration,
            "cell_size_um": float(median_size_scalar * self.spatial_calibration),
            "dataset": {"train": self.files_train, "validation": self.files_val},
        }

        # JSON fallback serializer for numpy scalars/arrays.
        # NOTE(review): duplicated in train_cellpose_model — candidate helper.
        def make_json_safe(obj):
            if isinstance(obj, np.ndarray):
                return obj.tolist()
            if isinstance(obj, (np.int64, np.int32)):
                return int(obj)
            if isinstance(obj, (np.float32, np.float64)):
                return float(obj)
            return str(obj)

        json_input_config = json.dumps(config_inputs, indent=4, default=make_json_safe)
        with open(
            os.sep.join([self.target_directory, self.model_name, "config_input.json"]),
            "w",
        ) as outfile:
            outfile.write(json_input_config)

    def train_cellpose_model(self):
        """Train (or fine-tune) a Cellpose model.

        Augmentation is materialized up front (Cellpose has no augmenter
        hook), progress is harvested from the "cellpose" logger via
        ``QueueLoggingHandler``, and the resulting weights file is moved out
        of Cellpose's nested ``models/`` directory. A ``config_input.json``
        describing the input contract is written at the end.
        """

        # do augmentation in place
        X_aug = []
        Y_aug = []
        # n_val here is the number of augmented samples to draw (with
        # replacement), not a validation count.
        n_val = max(1, int(round(self.augmentation_factor * len(self.X_trn))))
        indices = random.choices(list(np.arange(len(self.X_trn))), k=n_val)
        logger.info("Performing image augmentation pre-training...")
        for i in tqdm(indices):
            x_aug, y_aug = augmenter(self.X_trn[i], self.Y_trn[i])
            X_aug.append(x_aug)
            Y_aug.append(y_aug)

        # Channel axis in front for cellpose_utils
        X_aug = [np.moveaxis(x, -1, 0) for x in X_aug]
        self.X_val = [np.moveaxis(x, -1, 0) for x in self.X_val]
        logger.info("number of augmented images: %3d" % len(X_aug))

        # Heavy backend imports deferred to training time.
        from cellpose.models import CellposeModel
        from cellpose.io import logger_setup
        import torch

        if not self.use_gpu:
            logger.info("Using CPU for training...")
            # NOTE(review): `device` is never used afterwards — CellposeModel
            # below only receives gpu=self.use_gpu.
            device = torch.device("cpu")
        else:
            logger.info("Using GPU for training...")

        # logger_setup configures console and file handlers for cellpose_utils
        _, log_file = logger_setup()

        # Get cellpose_utils logger explicitly to ensure we catch all cellpose_utils logs (e.g. from models)
        logger_cellpose = logging.getLogger("cellpose")

        # Add custom handler
        handler = QueueLoggingHandler(
            self.queue,
            self.epochs,
            stop_event=self.stop_event if hasattr(self, "stop_event") else None,
        )
        handler.setLevel(logging.INFO)
        logger_cellpose.addHandler(handler)

        try:
            logger.info(f"Pretrained model: {self.pretrained}")
            # Cellpose weight files live inside a directory of the same name.
            if self.pretrained is not None:
                pretrained_path = os.sep.join(
                    [self.pretrained, os.path.split(self.pretrained)[-1]]
                )
            else:
                pretrained_path = self.pretrained

            model = CellposeModel(
                gpu=self.use_gpu,
                model_type=None,
                pretrained_model=pretrained_path,
                diam_mean=30.0,
                nchan=X_aug[0].shape[0],
            )

            if self.pretrained is not None:
                logger.info("Freezing encoder layers for Cellpose model...")
                for param in model.net.downsample.parameters():
                    param.requires_grad = False

                # Optional: freeze style branch
                for param in model.net.make_style.parameters():
                    param.requires_grad = False

                # Keep decoder trainable
                for param in model.net.upsample.parameters():
                    param.requires_grad = True

                # Keep output head trainable
                for param in model.net.output.parameters():
                    param.requires_grad = True

                # Unfreeze all output heads (version-safe)
                output_heads = ["output", "output_conv", "flow", "prob"]
                for head_name in output_heads:
                    if hasattr(model.net, head_name):
                        for param in getattr(model.net, head_name).parameters():
                            param.requires_grad = True

            # normalize=False: images were already normalized in load_dataset.
            model.train(
                train_data=X_aug,
                train_labels=Y_aug,
                normalize=False,
                channels=None,
                batch_size=self.batch_size,
                min_train_masks=1,
                save_path=self.target_directory + os.sep + self.model_name,
                n_epochs=self.epochs,
                model_name=self.model_name,
                learning_rate=self.learning_rate,
                test_data=self.X_val,
                test_labels=self.Y_val,
            )
        except InterruptedError:
            # Raised by QueueLoggingHandler when stop_event is set.
            logger.info("Training interrupted.")
        except Exception as e:
            logger.error(f"Error during training: {e}")
            raise e
        finally:
            logger_cellpose.removeHandler(handler)

        # Cellpose saves under <save_path>/models/<name>; flatten that layout.
        # NOTE(review): glob(...)[0] raises IndexError if training saved
        # nothing (e.g. after an interrupt) — verify this path on interrupt.
        file_to_move = glob(
            os.sep.join([self.target_directory, self.model_name, "models", "*"])
        )[0]
        shutil.move(
            file_to_move,
            os.sep.join([self.target_directory, self.model_name, ""])
            + os.path.split(file_to_move)[-1],
        )
        os.rmdir(os.sep.join([self.target_directory, self.model_name, "models"]))

        # Mean label diameter estimated by Cellpose from the training masks.
        diameter = model.diam_labels

        # Cellpose convention: nuclei models assume 17 px, others 30 px.
        if (
            self.pretrained is not None
            and os.path.split(self.pretrained)[-1] == "CP_nuclei"
        ):
            standard_diameter = 17.0
        else:
            standard_diameter = 30.0

        input_spatial_calibration = (
            self.spatial_calibration
        )  # *diameter / standard_diameter

        # Input contract consumed at inference time.
        config_inputs = {
            "channels": self.target_channels,
            "diameter": standard_diameter,
            "cellprob_threshold": 0.0,
            "flow_threshold": 0.4,
            "normalization_percentile": self.normalization_percentile,
            "normalization_clip": self.normalization_clip,
            "normalization_values": self.normalization_values,
            "model_type": "cellpose",
            "spatial_calibration": input_spatial_calibration,
            "cell_size_um": round(diameter * input_spatial_calibration, 4),
            "dataset": {"train": self.files_train, "validation": self.files_val},
        }

        # JSON fallback serializer for numpy scalars/arrays.
        # NOTE(review): duplicated in train_stardist_model — candidate helper.
        def make_json_safe(obj):
            if isinstance(obj, np.ndarray):
                return obj.tolist()
            if isinstance(obj, (np.int64, np.int32)):
                return int(obj)
            if isinstance(obj, (np.float32, np.float64)):
                return float(obj)
            return str(obj)

        json_input_config = json.dumps(config_inputs, indent=4, default=make_json_safe)
        with open(
            os.sep.join([self.target_directory, self.model_name, "config_input.json"]),
            "w",
        ) as outfile:
            outfile.write(json_input_config)

    def split_test_train(self):
        """Randomly split X/Y (and filenames) into training and validation sets.

        Uses ``self.validation_split`` as the validation fraction, with at
        least one validation sample.
        """

        if not len(self.X) > 1:
            logger.error("Not enough training data")
            # NOTE(review): abort_process does not raise or return here, so
            # execution falls through to the split below — verify.
            self.abort_process()

        # Unseeded RNG: the split differs between runs. The chosen filenames
        # are recorded in config_input.json for reproducibility.
        rng = np.random.RandomState()
        ind = rng.permutation(len(self.X))
        n_val = max(1, int(round(self.validation_split * len(ind))))
        ind_train, ind_val = ind[:-n_val], ind[-n_val:]
        self.X_val, self.Y_val = [self.X[i] for i in ind_val], [
            self.Y[i] for i in ind_val
        ]
        self.X_trn, self.Y_trn = [self.X[i] for i in ind_train], [
            self.Y[i] for i in ind_train
        ]

        self.files_train = [self.filenames[i] for i in ind_train]
        self.files_val = [self.filenames[i] for i in ind_val]

        logger.info("number of images: %3d" % len(self.X))
        logger.info("- training: %3d" % len(self.X_trn))
        logger.info("- validation: %3d" % len(self.X_val))

    def extract_training_params(self):
        """Copy the fields of ``self.training_instructions`` onto attributes.

        Raises ``KeyError`` if a required key is missing from the JSON.
        """

        self.model_name = self.training_instructions["model_name"]
        self.target_directory = self.training_instructions["target_directory"]
        self.model_type = self.training_instructions["model_type"]
        self.pretrained = self.training_instructions["pretrained"]

        self.datasets = self.training_instructions["ds"]

        self.target_channels = self.training_instructions["channel_option"]
        self.normalization_percentile = self.training_instructions[
            "normalization_percentile"
        ]
        self.normalization_clip = self.training_instructions["normalization_clip"]
        self.normalization_values = self.training_instructions["normalization_values"]
        self.spatial_calibration = self.training_instructions["spatial_calibration"]

        self.validation_split = self.training_instructions["validation_split"]
        self.augmentation_factor = self.training_instructions["augmentation_factor"]

        self.learning_rate = self.training_instructions["learning_rate"]
        self.epochs = self.training_instructions["epochs"]
        self.batch_size = self.training_instructions["batch_size"]

    def load_dataset(self):
        """Load images/masks, normalize per channel, interpolate NaNs, fill holes.

        Populates ``self.X`` (normalized multichannel images), ``self.Y``
        (label masks with holes filled) and ``self.filenames``.
        """

        logger.info(f"Datasets: {self.datasets}")
        self.X, self.Y, self.filenames = load_image_dataset(
            self.datasets,
            self.target_channels,
            train_spatial_calibration=self.spatial_calibration,
            mask_suffix="labelled",
        )
        logger.info("Dataset loaded...")

        # Per channel, normalization is either percentile-based or by fixed
        # values — exactly one of the two lists holds a non-None entry.
        self.values = []
        self.percentiles = []
        for k in range(len(self.normalization_percentile)):
            if self.normalization_percentile[k]:
                self.percentiles.append(self.normalization_values[k])
                self.values.append(None)
            else:
                self.percentiles.append(None)
                self.values.append(self.normalization_values[k])

        self.X = [
            normalize_multichannel(
                x,
                **{
                    "percentiles": self.percentiles,
                    "values": self.values,
                    "clip": self.normalization_clip,
                },
            )
            for x in self.X
        ]

        # Replace NaNs channel-by-channel via interpolation (channels last).
        for k in range(len(self.X)):
            x = self.X[k].copy()
            x_interp = np.moveaxis(
                [interpolate_nan(x[:, :, c].copy()) for c in range(x.shape[-1])], 0, -1
            )
            self.X[k] = x_interp

        self.Y = [fill_label_holes(y) for y in tqdm(self.Y)]

    def end_process(self):
        """Terminate the process and report success on the queue.

        NOTE(review): Process.terminate() is documented for use from the
        parent; if this runs inside the child, the put("finished") after a
        self-terminate may never execute — verify the intended call site.
        """

        self.terminate()
        self.queue.put("finished")

    def abort_process(self):
        """Terminate the process and report failure on the queue.

        NOTE(review): also called from __init__ (parent side, before start());
        terminating an unstarted Process raises — verify.
        """

        self.terminate()
        self.queue.put("error")
|