celldetective 1.4.2__py3-none-any.whl → 1.5.0b0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- celldetective/__init__.py +25 -0
- celldetective/__main__.py +62 -43
- celldetective/_version.py +1 -1
- celldetective/extra_properties.py +477 -399
- celldetective/filters.py +192 -97
- celldetective/gui/InitWindow.py +541 -411
- celldetective/gui/__init__.py +0 -15
- celldetective/gui/about.py +44 -39
- celldetective/gui/analyze_block.py +120 -84
- celldetective/gui/base/__init__.py +0 -0
- celldetective/gui/base/channel_norm_generator.py +335 -0
- celldetective/gui/base/components.py +249 -0
- celldetective/gui/base/feature_choice.py +92 -0
- celldetective/gui/base/figure_canvas.py +52 -0
- celldetective/gui/base/list_widget.py +133 -0
- celldetective/gui/{styles.py → base/styles.py} +92 -36
- celldetective/gui/base/utils.py +33 -0
- celldetective/gui/base_annotator.py +900 -767
- celldetective/gui/classifier_widget.py +6 -22
- celldetective/gui/configure_new_exp.py +777 -671
- celldetective/gui/control_panel.py +635 -524
- celldetective/gui/dynamic_progress.py +449 -0
- celldetective/gui/event_annotator.py +2023 -1662
- celldetective/gui/generic_signal_plot.py +1292 -944
- celldetective/gui/gui_utils.py +899 -1289
- celldetective/gui/interactions_block.py +658 -0
- celldetective/gui/interactive_timeseries_viewer.py +447 -0
- celldetective/gui/json_readers.py +48 -15
- celldetective/gui/layouts/__init__.py +5 -0
- celldetective/gui/layouts/background_model_free_layout.py +537 -0
- celldetective/gui/layouts/channel_offset_layout.py +134 -0
- celldetective/gui/layouts/local_correction_layout.py +91 -0
- celldetective/gui/layouts/model_fit_layout.py +372 -0
- celldetective/gui/layouts/operation_layout.py +68 -0
- celldetective/gui/layouts/protocol_designer_layout.py +96 -0
- celldetective/gui/pair_event_annotator.py +3130 -2435
- celldetective/gui/plot_measurements.py +586 -267
- celldetective/gui/plot_signals_ui.py +724 -506
- celldetective/gui/preprocessing_block.py +395 -0
- celldetective/gui/process_block.py +1678 -1831
- celldetective/gui/seg_model_loader.py +580 -473
- celldetective/gui/settings/__init__.py +0 -7
- celldetective/gui/settings/_cellpose_model_params.py +181 -0
- celldetective/gui/settings/_event_detection_model_params.py +95 -0
- celldetective/gui/settings/_segmentation_model_params.py +159 -0
- celldetective/gui/settings/_settings_base.py +77 -65
- celldetective/gui/settings/_settings_event_model_training.py +752 -526
- celldetective/gui/settings/_settings_measurements.py +1133 -964
- celldetective/gui/settings/_settings_neighborhood.py +574 -488
- celldetective/gui/settings/_settings_segmentation_model_training.py +779 -564
- celldetective/gui/settings/_settings_signal_annotator.py +329 -305
- celldetective/gui/settings/_settings_tracking.py +1304 -1094
- celldetective/gui/settings/_stardist_model_params.py +98 -0
- celldetective/gui/survival_ui.py +422 -312
- celldetective/gui/tableUI.py +1665 -1701
- celldetective/gui/table_ops/_maths.py +295 -0
- celldetective/gui/table_ops/_merge_groups.py +140 -0
- celldetective/gui/table_ops/_merge_one_hot.py +95 -0
- celldetective/gui/table_ops/_query_table.py +43 -0
- celldetective/gui/table_ops/_rename_col.py +44 -0
- celldetective/gui/thresholds_gui.py +382 -179
- celldetective/gui/viewers/__init__.py +0 -0
- celldetective/gui/viewers/base_viewer.py +700 -0
- celldetective/gui/viewers/channel_offset_viewer.py +331 -0
- celldetective/gui/viewers/contour_viewer.py +394 -0
- celldetective/gui/viewers/size_viewer.py +153 -0
- celldetective/gui/viewers/spot_detection_viewer.py +341 -0
- celldetective/gui/viewers/threshold_viewer.py +309 -0
- celldetective/gui/workers.py +304 -126
- celldetective/log_manager.py +92 -0
- celldetective/measure.py +1895 -1478
- celldetective/napari/__init__.py +0 -0
- celldetective/napari/utils.py +1025 -0
- celldetective/neighborhood.py +1914 -1448
- celldetective/preprocessing.py +1620 -1220
- celldetective/processes/__init__.py +0 -0
- celldetective/processes/background_correction.py +271 -0
- celldetective/processes/compute_neighborhood.py +894 -0
- celldetective/processes/detect_events.py +246 -0
- celldetective/processes/measure_cells.py +565 -0
- celldetective/processes/segment_cells.py +760 -0
- celldetective/processes/track_cells.py +435 -0
- celldetective/processes/train_segmentation_model.py +694 -0
- celldetective/processes/train_signal_model.py +265 -0
- celldetective/processes/unified_process.py +292 -0
- celldetective/regionprops/_regionprops.py +358 -317
- celldetective/relative_measurements.py +987 -710
- celldetective/scripts/measure_cells.py +313 -212
- celldetective/scripts/measure_relative.py +90 -46
- celldetective/scripts/segment_cells.py +165 -104
- celldetective/scripts/segment_cells_thresholds.py +96 -68
- celldetective/scripts/track_cells.py +198 -149
- celldetective/scripts/train_segmentation_model.py +324 -201
- celldetective/scripts/train_signal_model.py +87 -45
- celldetective/segmentation.py +844 -749
- celldetective/signals.py +3514 -2861
- celldetective/tracking.py +30 -15
- celldetective/utils/__init__.py +0 -0
- celldetective/utils/cellpose_utils/__init__.py +133 -0
- celldetective/utils/color_mappings.py +42 -0
- celldetective/utils/data_cleaning.py +630 -0
- celldetective/utils/data_loaders.py +450 -0
- celldetective/utils/dataset_helpers.py +207 -0
- celldetective/utils/downloaders.py +197 -0
- celldetective/utils/event_detection/__init__.py +8 -0
- celldetective/utils/experiment.py +1782 -0
- celldetective/utils/image_augmenters.py +308 -0
- celldetective/utils/image_cleaning.py +74 -0
- celldetective/utils/image_loaders.py +926 -0
- celldetective/utils/image_transforms.py +335 -0
- celldetective/utils/io.py +62 -0
- celldetective/utils/mask_cleaning.py +348 -0
- celldetective/utils/mask_transforms.py +5 -0
- celldetective/utils/masks.py +184 -0
- celldetective/utils/maths.py +351 -0
- celldetective/utils/model_getters.py +325 -0
- celldetective/utils/model_loaders.py +296 -0
- celldetective/utils/normalization.py +380 -0
- celldetective/utils/parsing.py +465 -0
- celldetective/utils/plots/__init__.py +0 -0
- celldetective/utils/plots/regression.py +53 -0
- celldetective/utils/resources.py +34 -0
- celldetective/utils/stardist_utils/__init__.py +104 -0
- celldetective/utils/stats.py +90 -0
- celldetective/utils/types.py +21 -0
- {celldetective-1.4.2.dist-info → celldetective-1.5.0b0.dist-info}/METADATA +1 -1
- celldetective-1.5.0b0.dist-info/RECORD +187 -0
- {celldetective-1.4.2.dist-info → celldetective-1.5.0b0.dist-info}/WHEEL +1 -1
- tests/gui/test_new_project.py +129 -117
- tests/gui/test_project.py +127 -79
- tests/test_filters.py +39 -15
- tests/test_notebooks.py +8 -0
- tests/test_tracking.py +232 -13
- tests/test_utils.py +123 -77
- celldetective/gui/base_components.py +0 -23
- celldetective/gui/layouts.py +0 -1602
- celldetective/gui/processes/compute_neighborhood.py +0 -594
- celldetective/gui/processes/measure_cells.py +0 -360
- celldetective/gui/processes/segment_cells.py +0 -499
- celldetective/gui/processes/track_cells.py +0 -303
- celldetective/gui/processes/train_segmentation_model.py +0 -270
- celldetective/gui/processes/train_signal_model.py +0 -108
- celldetective/gui/table_ops/merge_groups.py +0 -118
- celldetective/gui/viewers.py +0 -1354
- celldetective/io.py +0 -3663
- celldetective/utils.py +0 -3108
- celldetective-1.4.2.dist-info/RECORD +0 -123
- /celldetective/{gui/processes → processes}/downloader.py +0 -0
- {celldetective-1.4.2.dist-info → celldetective-1.5.0b0.dist-info}/entry_points.txt +0 -0
- {celldetective-1.4.2.dist-info → celldetective-1.5.0b0.dist-info}/licenses/LICENSE +0 -0
- {celldetective-1.4.2.dist-info → celldetective-1.5.0b0.dist-info}/top_level.txt +0 -0
|
@@ -0,0 +1,265 @@
|
|
|
1
|
+
from multiprocessing import Process
|
|
2
|
+
import time
|
|
3
|
+
import os
|
|
4
|
+
import json
|
|
5
|
+
from glob import glob
|
|
6
|
+
import numpy as np
|
|
7
|
+
from art import tprint
|
|
8
|
+
from tensorflow.python.keras.callbacks import Callback
|
|
9
|
+
|
|
10
|
+
from celldetective.signals import SignalDetectionModel
|
|
11
|
+
from celldetective.log_manager import get_logger
|
|
12
|
+
from celldetective.utils.model_loaders import locate_signal_model
|
|
13
|
+
|
|
14
|
+
logger = get_logger(__name__)
|
|
15
|
+
|
|
16
|
+
|
|
17
|
+
class ProgressCallback(Callback):
|
|
18
|
+
|
|
19
|
+
def __init__(self, queue=None, total_epochs=100, stop_event=None):
    """Store the reporting channel and reset the progress bookkeeping.

    Parameters
    ----------
    queue : object, optional
        Object exposing ``put``; progress messages are pushed to it.
        The reporting methods skip sending when it is ``None``.
    total_epochs : int, optional
        Number of epochs used as the denominator of the epoch progress.
    stop_event : object, optional
        Object exposing ``is_set`` and ``clear``; it is polled at the end
        of every epoch to decide whether training should be interrupted.
    """
    super().__init__()
    self.stop_event = stop_event
    self.queue, self.total_epochs = queue, total_epochs
    # Reference time for the ETA estimate, and count of finished epochs.
    self.t0 = time.time()
    self.current_step = 0
|
|
26
|
+
|
|
27
|
+
def on_epoch_begin(self, epoch, logs=None):
    """Record the wall-clock time at which the epoch starts.

    ``epoch`` and ``logs`` are accepted for the callback signature but are
    not used here.
    """
    started_at = time.time()
    self.epoch_start_time = started_at
|
|
29
|
+
|
|
30
|
+
def on_batch_end(self, batch, logs=None):
    """Report the within-epoch batch progress through the queue.

    Sends ``{"frame_progress": <percent>, "frame_time": "Batch i/N"}`` where
    ``N`` is read from ``self.params["steps"]``. Nothing is sent when the
    step count is unavailable or when no queue was provided.

    Parameters
    ----------
    batch : int
        Zero-based index of the batch that just finished.
    logs : dict, optional
        Unused.
    """
    if not self.params or "steps" not in self.params:
        return

    total_steps = self.params["steps"]
    # The entry can exist without a usable value (None or 0 when the number
    # of steps per epoch is unknown); dividing by it would raise, so skip
    # the update instead of breaking the training loop.
    if not total_steps:
        return

    if self.queue is None:
        return

    # Frequent, lightweight update of the batch bar.
    self.queue.put(
        {
            "frame_progress": ((batch + 1) / total_steps) * 100,
            "frame_time": f"Batch {batch + 1}/{total_steps}",
        }
    )
|
|
46
|
+
|
|
47
|
+
def on_epoch_end(self, epoch, logs=None):
    """Handle a stop request and report epoch progress through the queue.

    If ``stop_event`` is set, training is flagged to stop and the event is
    cleared. The finished-epoch counter is then incremented and, when a queue
    is available, a message is sent with the epoch progress (``pos_progress``
    / ``pos_time``), a reset of the batch bar (``frame_progress`` /
    ``frame_time``) and, if ``logs`` is non-empty, a ``plot_data`` entry
    holding the training and validation metrics.

    Parameters
    ----------
    epoch : int
        Zero-based index of the epoch that just finished.
    logs : dict, optional
        Metric name -> value mapping; keys starting with ``val_`` are
        reported as validation metrics.
    """

    def _as_floats(pairs):
        """Convert metric values to float, dropping entries that are not scalar."""
        converted = {}
        for name, value in pairs:
            try:
                converted[name] = float(value)
            except (TypeError, ValueError):
                continue
        return converted

    if self.stop_event and self.stop_event.is_set():
        logger.info("Interrupting training...")
        self.model.stop_training = True
        self.stop_event.clear()

    self.current_step += 1

    if self.queue is None:
        return

    # Guard the division: a zero epoch budget would otherwise raise here.
    if self.total_epochs:
        sum_done = self.current_step / self.total_epochs * 100
    else:
        sum_done = 0.0

    mean_exec_per_step = (time.time() - self.t0) / self.current_step
    # Never advertise a negative ETA if more epochs ran than were planned.
    pred_time = max(
        (self.total_epochs - self.current_step) * mean_exec_per_step, 0.0
    )
    if pred_time > 60:
        time_str = f"{pred_time/60:.1f} min"
    else:
        time_str = f"{pred_time:.1f} s"

    # Epoch bar update; the batch bar is reset for the next epoch.
    msg = {
        "pos_progress": sum_done,
        "pos_time": f"Epoch {self.current_step}/{self.total_epochs} (ETA: {time_str})",
        "frame_progress": 0,
        "frame_time": "Batch 0/0",
    }

    if logs:
        # The model kind is inferred from the presence of an "iou" metric.
        model_name = "Classifier" if "iou" in logs else "Regressor"
        # params may not have been populated; fall back to the configured total.
        params = self.params or {}
        msg["plot_data"] = {
            "epoch": epoch + 1,  # 1-based for plot
            "metrics": _as_floats(
                (k, v) for k, v in logs.items() if not k.startswith("val_")
            ),
            "val_metrics": _as_floats(
                (k, v) for k, v in logs.items() if k.startswith("val_")
            ),
            "model_name": model_name,
            "total_epochs": params.get("epochs", self.total_epochs),
        }

    self.queue.put(msg)
|
|
94
|
+
|
|
95
|
+
def on_training_result(self, result):
    """Forward ``result`` to the queue under the ``"training_result"`` key.

    Does nothing when no queue was provided.
    """
    if self.queue is None:
        return
    payload = {"training_result": result}
    self.queue.put(payload)
|
|
98
|
+
|
|
99
|
+
|
|
100
|
+
class TrainSignalModelProcess(Process):
|
|
101
|
+
|
|
102
|
+
def __init__(self, queue=None, process_args=None, *args, **kwargs):
    """Build the training process and prepare its parameters.

    Parameters
    ----------
    queue : object, optional
        Channel used to send status messages to the consumer.
    process_args : dict, optional
        Each key/value pair is set as an attribute on the process.
        ``read_instructions`` reads ``self.instructions``, so that key is
        expected to be provided here.
    *args, **kwargs
        Forwarded to the parent ``Process`` constructor.
    """
    super().__init__(*args, **kwargs)

    self.queue = queue

    if process_args is not None:
        for attr_name, attr_value in process_args.items():
            setattr(self, attr_name, attr_value)

    tprint("Train event detection")

    # Order matters: extract_training_params works on what
    # read_instructions has loaded.
    self.read_instructions()
    self.extract_training_params()

    self.t0 = time.time()
    self.sum_done = 0
|
|
118
|
+
|
|
119
|
+
def read_instructions(self):
    """Load the training instructions and derive the model/training kwargs.

    Reads the JSON file at ``self.instructions`` into
    ``self.training_instructions``, counts the distinct ``"class"`` values
    found across every ``*.npy`` file of the dataset folders listed under
    ``"ds"``, then builds:

    * ``self.model_params`` - the whitelisted model keys present in the
      instructions, plus ``"n_classes"``;
    * ``self.train_params`` - the whitelisted training keys present in the
      instructions.

    Raises
    ------
    FileNotFoundError
        If the instructions file does not exist.
    """
    if not os.path.exists(self.instructions):
        logger.error("Training instructions could not be found. Abort.")
        self.abort_process()
        # Nothing below can work without the instructions; fail with an
        # explicit error instead of tripping on the attribute that was
        # never set.
        raise FileNotFoundError(
            f"Training instructions not found: {self.instructions}"
        )

    with open(self.instructions, "r") as f:
        self.training_instructions = json.load(f)

    all_classes = []
    for dataset_dir in self.training_instructions["ds"]:
        for dataset_path in glob(dataset_dir + os.sep + "*.npy"):
            # NOTE(review): allow_pickle=True unpickles arbitrary objects;
            # only point this at dataset files you trust.
            data = np.load(dataset_path, allow_pickle=True)
            all_classes.extend(np.unique([sample["class"] for sample in data]))
    n_classes = len(np.unique(all_classes))

    model_keys = (
        "pretrained",
        "model_signal_length",
        "channel_option",
        "n_channels",
        "label",
    )
    train_keys = (
        "model_name",
        "target_directory",
        "channel_option",
        "recompile_pretrained",
        "test_split",
        "augment",
        "epochs",
        "learning_rate",
        "batch_size",
        "validation_split",
        "normalization_percentile",
        "normalization_values",
        "normalization_clip",
    )

    self.model_params = {
        k: self.training_instructions[k]
        for k in model_keys
        if k in self.training_instructions
    }
    self.model_params["n_classes"] = n_classes
    self.train_params = {
        k: self.training_instructions[k]
        for k in train_keys
        if k in self.training_instructions
    }
|
|
169
|
+
|
|
170
|
+
def neighborhood_postprocessing(self):
    """Write the neighborhood settings into the trained model's config file.

    When the training instructions hold a non-``None``
    ``"neighborhood_of_interest"``, the model folder is looked up with
    ``locate_signal_model`` and its ``config_input.json`` is rewritten with
    the ``"neighborhood_of_interest"``, ``"reference_population"`` and
    ``"neighbor_population"`` values of the instructions. Otherwise the
    method returns without touching anything.
    """
    instructions = self.training_instructions
    if instructions.get("neighborhood_of_interest") is None:
        return

    model_path = locate_signal_model(
        instructions["model_name"], path=None, pairs=True
    )
    # NOTE(review): presumably the lookup yields None when the model is not
    # found - confirm against locate_signal_model. Joining None into a path
    # would raise, so report the problem and leave instead.
    if model_path is None:
        logger.error(
            f"Model {instructions['model_name']} could not be located; "
            "neighborhood settings were not written to its config."
        )
        return

    model_config_path = os.sep.join([model_path, "config_input.json"])

    # Read with a context manager so the handle is released before the file
    # is reopened for writing.
    with open(model_config_path) as config_file:
        config = json.load(config_file)

    config.update(
        {
            "neighborhood_of_interest": instructions["neighborhood_of_interest"],
            "reference_population": instructions["reference_population"],
            "neighbor_population": instructions["neighbor_population"],
        }
    )

    with open(model_config_path, "w") as outfile:
        outfile.write(json.dumps(config))
|
|
202
|
+
|
|
203
|
+
def run(self):
    """Train the event detection model and report through the queue.

    Sends a ``"Loading datasets..."`` status, builds a
    ``SignalDetectionModel`` from ``self.model_params`` and fits it on the
    dataset folders with a ``ProgressCallback`` attached. After training,
    the entries of ``model.dico`` listed below (when present) are sent as
    ``"training_result"``, the neighborhood settings are written to the
    model config, and ``"finished"`` is sent before the queue is closed.
    """
    self.queue.put({"status": "Loading datasets..."})
    model = SignalDetectionModel(**self.model_params)

    # The progress bar spans three times the configured epoch count.
    # NOTE(review): presumably fit_from_directory runs several training
    # phases - confirm the factor against SignalDetectionModel.
    progress_callback = ProgressCallback(
        queue=self.queue,
        total_epochs=self.train_params["epochs"] * 3,
        stop_event=getattr(self, "stop_event", None),
    )

    datasets = self.training_instructions["ds"]
    model.fit_from_directory(
        datasets, callbacks=[progress_callback], **self.train_params
    )

    # Send results to GUI
    if hasattr(model, "dico"):
        wanted = (
            "val_confusion",
            "test_confusion",
            "val_predictions",
            "val_ground_truth",
            "test_predictions",
            "test_ground_truth",
            "val_mse",
        )
        summary = {}
        for key in wanted:
            if key in model.dico:
                summary[key] = model.dico[key]
        # Only send if we have something relevant
        if summary:
            self.queue.put({"training_result": summary})

    self.neighborhood_postprocessing()
    self.queue.put("finished")
    self.queue.close()
|
|
237
|
+
|
|
238
|
+
def extract_training_params(self):
    """Derive the channel count, augmentation flag and test split.

    Updates ``self.training_instructions`` with:

    * ``"n_channels"`` - length of ``"channel_option"`` (also copied into
      ``self.model_params``);
    * ``"augment"`` - ``True`` when ``"augmentation_factor"`` exceeds 1.0;
    * ``"test_split"`` - always 0.0.

    ``"augment"`` and ``"test_split"`` are also copied into
    ``self.train_params`` when that dict exists. It is built by
    ``read_instructions`` before this method runs, so without the copy the
    two derived values would never reach the training call.
    """
    instructions = self.training_instructions

    instructions["n_channels"] = len(instructions["channel_option"])
    self.model_params["n_channels"] = instructions["n_channels"]

    instructions["augment"] = bool(instructions["augmentation_factor"] > 1.0)
    instructions["test_split"] = 0.0

    train_params = getattr(self, "train_params", None)
    if train_params is not None:
        for key in ("augment", "test_split"):
            train_params[key] = instructions[key]
    # NOTE(review): "augmentation_factor" itself is not among the keys
    # forwarded to training - confirm whether fit_from_directory expects it.
|
|
249
|
+
|
|
250
|
+
def end_process(self):
    """Stop the process if it is running and send ``"finished"``.

    Termination is only attempted on a live process, and any error raised
    while doing so is logged, so that the ``"finished"`` message is still
    delivered. Nothing is sent when no queue was provided.
    """
    try:
        if self.is_alive():
            self.terminate()
    except Exception as e:
        logger.error(f"Error terminating process: {e}")

    if self.queue is not None:
        self.queue.put("finished")
|
|
261
|
+
|
|
262
|
+
def abort_process(self):
    """Stop the process if it is running and send ``"error"``.

    This can be reached from ``read_instructions`` while the object is still
    being constructed, i.e. before the process was started; termination is
    therefore only attempted on a live process and any error raised while
    doing so is logged, so that the ``"error"`` message is still delivered.
    Nothing is sent when no queue was provided.
    """
    try:
        if self.is_alive():
            self.terminate()
    except Exception as e:
        logger.error(f"Error terminating process: {e}")

    if self.queue is not None:
        self.queue.put("error")
|
|
@@ -0,0 +1,292 @@
|
|
|
1
|
+
import time
|
|
2
|
+
import os
|
|
3
|
+
import gc
|
|
4
|
+
from multiprocessing import Process
|
|
5
|
+
from pathlib import Path
|
|
6
|
+
|
|
7
|
+
from celldetective.log_manager import get_logger
|
|
8
|
+
|
|
9
|
+
logger = get_logger(__name__)
|
|
10
|
+
|
|
11
|
+
|
|
12
|
+
class UnifiedBatchProcess(Process):
|
|
13
|
+
"""
|
|
14
|
+
A unified process that handles Segmentation, Tracking, Measurement, and Signal Analysis
|
|
15
|
+
in a sequential manner for each position, updating a multi-bar progress window.
|
|
16
|
+
"""
|
|
17
|
+
|
|
18
|
+
def __init__(self, queue, process_args=None):
    """Store the queue and unpack the batch configuration.

    Parameters
    ----------
    queue : object
        Channel used to send status and progress messages.
    process_args : dict, optional
        Batch configuration. Recognised keys: ``"batch_structure"``,
        the ``"run_segmentation"`` / ``"run_tracking"`` /
        ``"run_measurement"`` / ``"run_signals"`` switches (default
        ``False``), the per-stage ``"seg_args"`` / ``"track_args"`` /
        ``"measure_args"`` / ``"signal_args"`` dicts (default ``{}``) and
        ``"log_file"`` (default ``None``). ``None`` is treated as an empty
        configuration.
    """
    super(UnifiedBatchProcess, self).__init__()
    self.queue = queue

    # The signature allows None; normalise it so the lookups below work.
    if process_args is None:
        process_args = {}
    self.process_args = process_args
    self.batch_structure = process_args.get("batch_structure", {})

    self.run_segmentation = process_args.get("run_segmentation", False)
    self.run_tracking = process_args.get("run_tracking", False)
    self.run_measurement = process_args.get("run_measurement", False)
    self.run_signals = process_args.get("run_signals", False)

    self.seg_args = process_args.get("seg_args", {})
    self.track_args = process_args.get("track_args", {})
    self.measure_args = process_args.get("measure_args", {})
    self.signal_args = process_args.get("signal_args", {})
    self.log_file = process_args.get("log_file", None)
|
|
34
|
+
|
|
35
|
+
def run(self):
    """Run the enabled stages on every position of every well.

    The workers of the enabled stages (segmentation, tracking, measurement,
    event detection) are created once, together with the models they need.
    Each well of ``self.batch_structure`` is then visited in order and, for
    each of its ``"positions"``, the stages are applied sequentially. A
    failure on one position is logged and reported, and processing moves on
    to the next position.

    Messages sent through the queue: ``{"status": ...}`` texts,
    ``{"well_progress", "well_time"}`` and ``{"pos_progress", "pos_time"}``
    updates, and finally the string ``"finished"``, after which the queue is
    closed.
    """

    def _time_left(elapsed, done, total):
        """Return (minutes, seconds) left, extrapolating the mean time per finished item."""
        remaining = (total - done) * (elapsed / done)
        return int(remaining // 60), int(remaining % 60)

    if self.log_file is not None:
        from celldetective.log_manager import setup_logging

        setup_logging(self.log_file)

    logger.info("Starting Unified Batch Process...")

    # Propagate batch structure to sub-processes so they can locate experiment config
    for args_dict in (
        self.seg_args,
        self.track_args,
        self.measure_args,
        self.signal_args,
    ):
        if isinstance(args_dict, dict):
            args_dict["batch_structure"] = self.batch_structure

    # --- Worker initialisation ---
    seg_worker = None
    model = None
    scale_model = None
    seg_kwargs = {}

    if self.run_segmentation:
        logger.info("Initializing the segmentation worker...")
        self.queue.put({"status": "Initializing segmentation..."})

        if "threshold_instructions" in self.seg_args:
            from celldetective.processes.segment_cells import (
                SegmentCellThresholdProcess,
            )

            seg_worker = SegmentCellThresholdProcess(
                queue=self.queue, process_args=self.seg_args
            )
        else:
            from celldetective.processes.segment_cells import SegmentCellDLProcess

            seg_worker = SegmentCellDLProcess(
                queue=self.queue, process_args=self.seg_args
            )

            # The model is loaded once here and reused for every position.
            if seg_worker.model_type == "stardist":
                logger.info("Loading the StarDist library...")
                from celldetective.utils.stardist_utils import _prep_stardist_model

                model, scale_model = _prep_stardist_model(
                    seg_worker.model_name,
                    Path(seg_worker.model_complete_path).parent,
                    use_gpu=seg_worker.use_gpu,
                    scale=seg_worker.scale,
                )
            elif seg_worker.model_type == "cellpose":
                logger.info("Loading the cellpose_utils library...")
                from celldetective.utils.cellpose_utils import _prep_cellpose_model

                model, scale_model = _prep_cellpose_model(
                    seg_worker.model_name,
                    seg_worker.model_complete_path,
                    use_gpu=seg_worker.use_gpu,
                    n_channels=len(seg_worker.required_channels),
                    scale=seg_worker.scale,
                )

            # Only the model-based worker receives the models.
            seg_kwargs = {"model": model, "scale_model": scale_model}

    track_worker = None
    if self.run_tracking:
        from celldetective.processes.track_cells import TrackingProcess

        logger.info("Initializing the tracking worker...")
        self.queue.put({"status": "Initializing tracking..."})
        track_worker = TrackingProcess(
            queue=self.queue, process_args=self.track_args
        )

    measure_worker = None
    if self.run_measurement:
        logger.info("Loading the measurement libraries...")
        from celldetective.processes.measure_cells import MeasurementProcess

        logger.info("Initializing the measurement worker...")
        self.queue.put({"status": "Initializing measurements..."})
        measure_worker = MeasurementProcess(
            queue=self.queue, process_args=self.measure_args
        )

    signal_worker = None
    signal_model = None

    if self.run_signals:
        from celldetective.utils.event_detection import _prep_event_detection_model

        try:
            logger.info("Loading the event detection model...")
            self.queue.put({"status": "Loading event detection model..."})
            model_name = self.signal_args["model_name"]
            signal_model = _prep_event_detection_model(
                model_name, use_gpu=self.signal_args.get("gpu", True)
            )
        except Exception as e:
            logger.error(
                f"Failed to initialize event detection model: {e}", exc_info=True
            )
            self.run_signals = False  # Disable signal analysis if model fails

    if self.run_signals:
        from celldetective.processes.detect_events import SignalAnalysisProcess

        logger.info("Initializing the event detection worker...")
        signal_worker = SignalAnalysisProcess(
            queue=self.queue, process_args=self.signal_args
        )
        signal_worker.signal_model_instance = signal_model

    # --- Stage plan ---
    # The switches no longer change from here on, so the plan is built once
    # instead of once per position. Each entry: (worker, verb used in the
    # status message, keyword arguments for process_position).
    total_steps = sum(
        1
        for enabled in (
            self.run_segmentation,
            self.run_tracking,
            self.run_measurement,
            self.run_signals,
        )
        if enabled
    )
    steps = []
    if self.run_segmentation and seg_worker:
        steps.append((seg_worker, "Segmenting", seg_kwargs))
    if self.run_tracking and track_worker:
        steps.append((track_worker, "Tracking", {}))
    if self.run_measurement and measure_worker:
        steps.append((measure_worker, "Measuring", {}))
    if self.run_signals and signal_worker:
        steps.append(
            (signal_worker, "Detecting events in position", {"model": signal_model})
        )

    n_wells = len(self.batch_structure)
    self.t0_well = time.time()

    for w_i, well_data in enumerate(self.batch_structure.values()):

        positions = well_data["positions"]
        n_positions = len(positions)

        # Well Progress Update
        if w_i > 0:
            mins, secs = _time_left(time.time() - self.t0_well, w_i, n_wells)
            well_str = f"Well {w_i + 1}/{n_wells} - {mins} m {secs} s left"
        else:
            well_str = f"Processing well {w_i + 1}/{n_wells}..."

        self.queue.put(
            {
                "well_progress": (w_i / n_wells) * 100,
                "well_time": well_str,
            }
        )

        self.t0_pos = time.time()

        for pos_idx, pos_path in enumerate(positions):

            # Position Progress Update
            if pos_idx > 0:
                mins_p, secs_p = _time_left(
                    time.time() - self.t0_pos, pos_idx, n_positions
                )
                pos_str = (
                    f"Pos {pos_idx + 1}/{n_positions} - {mins_p} m {secs_p} s left"
                )
            else:
                pos_str = f"Processing position {pos_idx + 1}/{n_positions}..."

            self.queue.put(
                {
                    "pos_progress": (pos_idx / n_positions) * 100,
                    "pos_time": pos_str,
                }
            )

            pos_name = os.path.basename(pos_path)

            try:
                for step_no, (worker, verb, kwargs) in enumerate(steps, start=1):
                    msg = f"[Step {step_no}/{total_steps}] {verb} {pos_name}..."
                    logger.info(msg)
                    self.queue.put({"status": msg})

                    worker.setup_for_position(pos_path)
                    worker.process_position(**kwargs)

            except Exception as e:
                # One log entry carrying the traceback; the remaining stages
                # of this position are skipped.
                logger.error(
                    f"Error processing position {pos_path}: {e}", exc_info=True
                )
                self.queue.put({"status": f"Error at {pos_name}. Skipping..."})

            # Run for failed positions too, so memory is released and the
            # position bar still reaches its end.
            gc.collect()

            # Update Position Progress (Complete)
            self.queue.put({"pos_progress": ((pos_idx + 1) / n_positions) * 100})

        # Update Well Progress (Complete)
        self.queue.put({"well_progress": ((w_i + 1) / n_wells) * 100})

    self.queue.put("finished")
    self.queue.close()
|
|
285
|
+
|
|
286
|
+
def end_process(self):
    """Terminate the process and wait for it, if it is still alive.

    Any exception raised along the way is logged rather than propagated.
    """
    try:
        if not self.is_alive():
            return
        self.terminate()
        self.join()
    except Exception as err:
        logger.error(f"Error terminating process: {err}")
|