cellfinder 1.6.0__py3-none-any.whl → 1.8.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- cellfinder/cli_migration_warning.py +3 -1
- cellfinder/core/classify/classify.py +50 -4
- cellfinder/core/classify/tools.py +13 -3
- cellfinder/core/detect/detect.py +73 -58
- cellfinder/core/detect/filters/plane/plane_filter.py +1 -1
- cellfinder/core/detect/filters/setup_filters.py +31 -12
- cellfinder/core/detect/filters/volume/ball_filter.py +5 -5
- cellfinder/core/detect/filters/volume/structure_splitting.py +2 -0
- cellfinder/core/detect/filters/volume/volume_filter.py +1 -1
- cellfinder/core/download/download.py +2 -1
- cellfinder/core/main.py +130 -16
- cellfinder/core/tools/threading.py +4 -3
- cellfinder/core/train/train_yaml.py +4 -13
- cellfinder/napari/curation.py +18 -2
- cellfinder/napari/detect/detect.py +61 -26
- cellfinder/napari/detect/detect_containers.py +31 -8
- cellfinder/napari/input_container.py +14 -4
- cellfinder/napari/train/train.py +3 -7
- cellfinder/napari/train/train_containers.py +0 -2
- {cellfinder-1.6.0.dist-info → cellfinder-1.8.0.dist-info}/METADATA +5 -4
- {cellfinder-1.6.0.dist-info → cellfinder-1.8.0.dist-info}/RECORD +25 -25
- {cellfinder-1.6.0.dist-info → cellfinder-1.8.0.dist-info}/WHEEL +1 -1
- {cellfinder-1.6.0.dist-info → cellfinder-1.8.0.dist-info}/entry_points.txt +0 -0
- {cellfinder-1.6.0.dist-info → cellfinder-1.8.0.dist-info/licenses}/LICENSE +0 -0
- {cellfinder-1.6.0.dist-info → cellfinder-1.8.0.dist-info}/top_level.txt +0 -0
cellfinder/core/main.py
CHANGED
|
@@ -1,34 +1,33 @@
|
|
|
1
1
|
import os
|
|
2
2
|
from typing import Callable, List, Optional, Tuple
|
|
3
3
|
|
|
4
|
-
import numpy as np
|
|
5
4
|
from brainglobe_utils.cells.cells import Cell
|
|
6
5
|
|
|
7
|
-
from cellfinder.core import logger
|
|
6
|
+
from cellfinder.core import logger, types
|
|
8
7
|
from cellfinder.core.download.download import model_type
|
|
9
8
|
from cellfinder.core.train.train_yaml import depth_type
|
|
10
9
|
|
|
11
10
|
|
|
12
11
|
def main(
|
|
13
|
-
signal_array:
|
|
14
|
-
background_array:
|
|
15
|
-
voxel_sizes: Tuple[
|
|
12
|
+
signal_array: types.array,
|
|
13
|
+
background_array: types.array,
|
|
14
|
+
voxel_sizes: Tuple[float, float, float],
|
|
16
15
|
start_plane: int = 0,
|
|
17
16
|
end_plane: int = -1,
|
|
18
17
|
trained_model: Optional[os.PathLike] = None,
|
|
19
18
|
model_weights: Optional[os.PathLike] = None,
|
|
20
19
|
model: model_type = "resnet50_tv",
|
|
21
|
-
|
|
20
|
+
classification_batch_size: int = 64,
|
|
22
21
|
n_free_cpus: int = 2,
|
|
23
|
-
network_voxel_sizes: Tuple[
|
|
24
|
-
soma_diameter:
|
|
25
|
-
ball_xy_size:
|
|
26
|
-
ball_z_size:
|
|
22
|
+
network_voxel_sizes: Tuple[float, float, float] = (5, 1, 1),
|
|
23
|
+
soma_diameter: float = 16,
|
|
24
|
+
ball_xy_size: float = 6,
|
|
25
|
+
ball_z_size: float = 15,
|
|
27
26
|
ball_overlap_fraction: float = 0.6,
|
|
28
27
|
log_sigma_size: float = 0.2,
|
|
29
28
|
n_sds_above_mean_thresh: float = 10,
|
|
30
29
|
soma_spread_factor: float = 1.4,
|
|
31
|
-
max_cluster_size:
|
|
30
|
+
max_cluster_size: float = 100000,
|
|
32
31
|
cube_width: int = 50,
|
|
33
32
|
cube_height: int = 50,
|
|
34
33
|
cube_depth: int = 20,
|
|
@@ -36,8 +35,13 @@ def main(
|
|
|
36
35
|
skip_detection: bool = False,
|
|
37
36
|
skip_classification: bool = False,
|
|
38
37
|
detected_cells: List[Cell] = None,
|
|
39
|
-
|
|
40
|
-
|
|
38
|
+
detection_batch_size: Optional[int] = None,
|
|
39
|
+
torch_device: Optional[str] = None,
|
|
40
|
+
pin_memory: bool = False,
|
|
41
|
+
split_ball_xy_size: float = 6,
|
|
42
|
+
split_ball_z_size: float = 15,
|
|
43
|
+
split_ball_overlap_fraction: float = 0.8,
|
|
44
|
+
n_splitting_iter: int = 10,
|
|
41
45
|
*,
|
|
42
46
|
detect_callback: Optional[Callable[[int], None]] = None,
|
|
43
47
|
classify_callback: Optional[Callable[[int], None]] = None,
|
|
@@ -46,6 +50,111 @@ def main(
|
|
|
46
50
|
"""
|
|
47
51
|
Parameters
|
|
48
52
|
----------
|
|
53
|
+
signal_array : numpy.ndarray or dask array
|
|
54
|
+
3D array representing the signal data in z, y, x order.
|
|
55
|
+
background_array : numpy.ndarray or dask array
|
|
56
|
+
3D array representing the signal data in z, y, x order.
|
|
57
|
+
voxel_sizes : 3-tuple of floats
|
|
58
|
+
Size of your voxels in the z, y, and x dimensions (microns).
|
|
59
|
+
start_plane : int
|
|
60
|
+
First plane index to process (inclusive, to process a subset of the
|
|
61
|
+
data).
|
|
62
|
+
end_plane : int
|
|
63
|
+
Last plane index to process (exclusive, to process a subset of the
|
|
64
|
+
data).
|
|
65
|
+
trained_model : Optional[Path]
|
|
66
|
+
Trained model file path (home directory (default) -> pretrained
|
|
67
|
+
weights).
|
|
68
|
+
model_weights : Optional[Path]
|
|
69
|
+
Model weights path (home directory (default) -> pretrained
|
|
70
|
+
weights).
|
|
71
|
+
model: str
|
|
72
|
+
Type of model to use. Defaults to `"resnet50_tv"`.
|
|
73
|
+
classification_batch_size : int
|
|
74
|
+
How many potential cells to classify at one time. The GPU/CPU
|
|
75
|
+
memory must be able to contain at once this many data cubes for
|
|
76
|
+
the models. For performance-critical applications, tune to maximize
|
|
77
|
+
memory usage without running out. Check your GPU/CPU memory to verify
|
|
78
|
+
it's not full.
|
|
79
|
+
n_free_cpus : int
|
|
80
|
+
How many CPU cores to leave free.
|
|
81
|
+
network_voxel_sizes : 3-tuple of floats
|
|
82
|
+
Size of the pre-trained network's voxels (microns) in the z, y, and x
|
|
83
|
+
dimensions.
|
|
84
|
+
soma_diameter : float
|
|
85
|
+
The expected in-plane (xy) soma diameter (microns).
|
|
86
|
+
ball_xy_size : float
|
|
87
|
+
3d filter's in-plane (xy) filter ball size (microns).
|
|
88
|
+
ball_z_size : float
|
|
89
|
+
3d filter's axial (z) filter ball size (microns).
|
|
90
|
+
ball_overlap_fraction : float
|
|
91
|
+
3d filter's fraction of the ball filter needed to be filled by
|
|
92
|
+
foreground voxels, centered on a voxel, to retain the voxel.
|
|
93
|
+
log_sigma_size : float
|
|
94
|
+
Gaussian filter width (as a fraction of soma diameter) used during
|
|
95
|
+
2d in-plane Laplacian of Gaussian filtering.
|
|
96
|
+
n_sds_above_mean_thresh : float
|
|
97
|
+
Intensity threshold (the number of standard deviations above
|
|
98
|
+
the mean) of the filtered 2d planes used to mark pixels as
|
|
99
|
+
foreground or background.
|
|
100
|
+
soma_spread_factor : float
|
|
101
|
+
Cell spread factor for determining the largest cell volume before
|
|
102
|
+
splitting up cell clusters. Structures with spherical volume of
|
|
103
|
+
diameter `soma_spread_factor * soma_diameter` or less will not be
|
|
104
|
+
split.
|
|
105
|
+
max_cluster_size : float
|
|
106
|
+
Largest detected cell cluster (in cubic um) where splitting
|
|
107
|
+
should be attempted. Clusters above this size will be labeled
|
|
108
|
+
as artifacts.
|
|
109
|
+
cube_width: int
|
|
110
|
+
The width of the data cube centered on the cell used for
|
|
111
|
+
classification. Defaults to `50`.
|
|
112
|
+
cube_height: int
|
|
113
|
+
The height of the data cube centered on the cell used for
|
|
114
|
+
classification. Defaults to `50`.
|
|
115
|
+
cube_depth: int
|
|
116
|
+
The depth of the data cube centered on the cell used for
|
|
117
|
+
classification. Defaults to `20`.
|
|
118
|
+
network_depth: str
|
|
119
|
+
The network depth to use during classification. Defaults to `"50"`.
|
|
120
|
+
skip_detection : bool
|
|
121
|
+
If selected, the detection step is skipped and instead we get the
|
|
122
|
+
detected cells from the cell layer below (from a previous
|
|
123
|
+
detection run or import).
|
|
124
|
+
skip_classification : bool
|
|
125
|
+
If selected, the classification step is skipped and all cells from
|
|
126
|
+
the detection stage are added.
|
|
127
|
+
detected_cells: Optional list of Cell objects.
|
|
128
|
+
If specified, the cells to use during classification.
|
|
129
|
+
detection_batch_size: int
|
|
130
|
+
The number of planes of the original data volume to process at
|
|
131
|
+
once. The GPU/CPU memory must be able to contain this many planes
|
|
132
|
+
for all the filters. For performance-critical applications, tune
|
|
133
|
+
to maximize memory usage without running out. Check your GPU/CPU
|
|
134
|
+
memory to verify it's not full.
|
|
135
|
+
torch_device : str, optional
|
|
136
|
+
The device on which to run the computation. If not specified (None),
|
|
137
|
+
"cuda" will be used if a GPU is available, otherwise "cpu".
|
|
138
|
+
You can also manually specify "cuda" or "cpu".
|
|
139
|
+
pin_memory: bool
|
|
140
|
+
Pins data to be sent to the GPU to the CPU memory. This allows faster
|
|
141
|
+
GPU data speeds, but can only be used if the data used by the GPU can
|
|
142
|
+
stay in the CPU RAM while the GPU uses it. I.e. there's enough RAM.
|
|
143
|
+
Otherwise, if there's a risk of the RAM being paged, it shouldn't be
|
|
144
|
+
used. Defaults to False.
|
|
145
|
+
split_ball_xy_size: float
|
|
146
|
+
Similar to `ball_xy_size`, except the value to use for the 3d
|
|
147
|
+
filter during cluster splitting.
|
|
148
|
+
split_ball_z_size: float
|
|
149
|
+
Similar to `ball_z_size`, except the value to use for the 3d filter
|
|
150
|
+
during cluster splitting.
|
|
151
|
+
split_ball_overlap_fraction: float
|
|
152
|
+
Similar to `ball_overlap_fraction`, except the value to use for the
|
|
153
|
+
3d filter during cluster splitting.
|
|
154
|
+
n_splitting_iter: int
|
|
155
|
+
The number of iterations to run the 3d filtering on a cluster. Each
|
|
156
|
+
iteration reduces the cluster size by the voxels not retained in
|
|
157
|
+
the previous iteration.
|
|
49
158
|
detect_callback : Callable[int], optional
|
|
50
159
|
Called every time a plane has finished being processed during the
|
|
51
160
|
detection stage. Called with the plane number that has finished.
|
|
@@ -76,9 +185,14 @@ def main(
|
|
|
76
185
|
n_free_cpus,
|
|
77
186
|
log_sigma_size,
|
|
78
187
|
n_sds_above_mean_thresh,
|
|
79
|
-
batch_size=
|
|
80
|
-
torch_device=
|
|
188
|
+
batch_size=detection_batch_size,
|
|
189
|
+
torch_device=torch_device,
|
|
190
|
+
pin_memory=pin_memory,
|
|
81
191
|
callback=detect_callback,
|
|
192
|
+
split_ball_z_size=split_ball_z_size,
|
|
193
|
+
split_ball_xy_size=split_ball_xy_size,
|
|
194
|
+
split_ball_overlap_fraction=split_ball_overlap_fraction,
|
|
195
|
+
n_splitting_iter=n_splitting_iter,
|
|
82
196
|
)
|
|
83
197
|
|
|
84
198
|
if detect_finished_callback is not None:
|
|
@@ -101,7 +215,7 @@ def main(
|
|
|
101
215
|
n_free_cpus,
|
|
102
216
|
voxel_sizes,
|
|
103
217
|
network_voxel_sizes,
|
|
104
|
-
|
|
218
|
+
classification_batch_size,
|
|
105
219
|
cube_height,
|
|
106
220
|
cube_width,
|
|
107
221
|
cube_depth,
|
|
@@ -15,6 +15,7 @@ Typical example::
|
|
|
15
15
|
|
|
16
16
|
from cellfinder.core.tools.threading import ThreadWithException, \\
|
|
17
17
|
EOFSignal, ProcessWithException
|
|
18
|
+
from cellfinder.core import logger
|
|
18
19
|
import torch
|
|
19
20
|
|
|
20
21
|
|
|
@@ -63,7 +64,7 @@ Typical example::
|
|
|
63
64
|
# thread exited for whatever reason (not exception)
|
|
64
65
|
break
|
|
65
66
|
|
|
66
|
-
|
|
67
|
+
logger.debug(f"Thread processed tensor {i}")
|
|
67
68
|
finally:
|
|
68
69
|
# whatever happens, make sure thread is told to finish so it
|
|
69
70
|
# doesn't get stuck
|
|
@@ -248,8 +249,8 @@ class ExceptionWithQueueMixIn:
|
|
|
248
249
|
... # do something with the msg
|
|
249
250
|
... pass
|
|
250
251
|
... except ExecutionFailure as e:
|
|
251
|
-
...
|
|
252
|
-
...
|
|
252
|
+
... logger.error(f"got exception {type(e.__cause__)}")
|
|
253
|
+
... logger.error(f"with message {e.__cause__.args[0]}")
|
|
253
254
|
"""
|
|
254
255
|
msg, value = self.from_thread_queue.get(block=True, timeout=timeout)
|
|
255
256
|
if msg == "eof":
|
|
@@ -3,9 +3,6 @@ main
|
|
|
3
3
|
===============
|
|
4
4
|
|
|
5
5
|
Trains a network based on a yaml file specifying cubes of cells/non cells.
|
|
6
|
-
|
|
7
|
-
N.B imports are within functions to prevent tensorflow being imported before
|
|
8
|
-
it's warnings are silenced
|
|
9
6
|
"""
|
|
10
7
|
|
|
11
8
|
import os
|
|
@@ -29,12 +26,16 @@ from brainglobe_utils.general.system import (
|
|
|
29
26
|
from brainglobe_utils.IO.cells import find_relevant_tiffs
|
|
30
27
|
from brainglobe_utils.IO.yaml import read_yaml_section
|
|
31
28
|
from fancylog import fancylog
|
|
29
|
+
from keras.callbacks import CSVLogger, ModelCheckpoint, TensorBoard
|
|
32
30
|
from sklearn.model_selection import train_test_split
|
|
33
31
|
|
|
34
32
|
import cellfinder.core as program_for_log
|
|
35
33
|
from cellfinder.core import logger
|
|
34
|
+
from cellfinder.core.classify.cube_generator import CubeGeneratorFromDisk
|
|
36
35
|
from cellfinder.core.classify.resnet import layer_type
|
|
36
|
+
from cellfinder.core.classify.tools import get_model, make_lists
|
|
37
37
|
from cellfinder.core.download.download import DEFAULT_DOWNLOAD_DIRECTORY
|
|
38
|
+
from cellfinder.core.tools.prep import prep_model_weights
|
|
38
39
|
|
|
39
40
|
depth_type = Literal["18", "34", "50", "101", "152"]
|
|
40
41
|
|
|
@@ -316,16 +317,6 @@ def run(
|
|
|
316
317
|
save_progress=False,
|
|
317
318
|
epochs=100,
|
|
318
319
|
):
|
|
319
|
-
from keras.callbacks import (
|
|
320
|
-
CSVLogger,
|
|
321
|
-
ModelCheckpoint,
|
|
322
|
-
TensorBoard,
|
|
323
|
-
)
|
|
324
|
-
|
|
325
|
-
from cellfinder.core.classify.cube_generator import CubeGeneratorFromDisk
|
|
326
|
-
from cellfinder.core.classify.tools import get_model, make_lists
|
|
327
|
-
from cellfinder.core.tools.prep import prep_model_weights
|
|
328
|
-
|
|
329
320
|
start_time = datetime.now()
|
|
330
321
|
|
|
331
322
|
ensure_directory_exists(output_dir)
|
cellfinder/napari/curation.py
CHANGED
|
@@ -95,6 +95,22 @@ class CurationWidget(QWidget):
|
|
|
95
95
|
self.training_data_non_cell_choice, self.point_layer_names
|
|
96
96
|
)
|
|
97
97
|
|
|
98
|
+
@self.viewer.layers.events.removed.connect
|
|
99
|
+
def _remove_selection_layers(event: QtCore.QEvent):
|
|
100
|
+
"""
|
|
101
|
+
Set internal background, signal, training data cell,
|
|
102
|
+
and training data non-cell layers to None when they
|
|
103
|
+
are removed from the napari viewer GUI.
|
|
104
|
+
"""
|
|
105
|
+
if event.value == self.signal_layer:
|
|
106
|
+
self.signal_layer = None
|
|
107
|
+
if event.value == self.background_layer:
|
|
108
|
+
self.background_layer = None
|
|
109
|
+
if event.value == self.training_data_cell_layer:
|
|
110
|
+
self.training_data_cell_layer = None
|
|
111
|
+
if event.value == self.training_data_non_cell_layer:
|
|
112
|
+
self.training_data_non_cell_layer = None
|
|
113
|
+
|
|
98
114
|
@staticmethod
|
|
99
115
|
def _update_combobox_options(combobox: QComboBox, options_list: List[str]):
|
|
100
116
|
original_text = combobox.currentText()
|
|
@@ -212,8 +228,8 @@ class CurationWidget(QWidget):
|
|
|
212
228
|
self.layout.addWidget(self.load_data_panel, row, column, 1, 1)
|
|
213
229
|
|
|
214
230
|
def setup_keybindings(self):
|
|
215
|
-
self.viewer.bind_key("c", self.mark_as_cell)
|
|
216
|
-
self.viewer.bind_key("x", self.mark_as_non_cell)
|
|
231
|
+
self.viewer.bind_key("c", self.mark_as_cell, overwrite=True)
|
|
232
|
+
self.viewer.bind_key("x", self.mark_as_non_cell, overwrite=True)
|
|
217
233
|
|
|
218
234
|
def set_signal_image(self):
|
|
219
235
|
"""
|
|
@@ -244,23 +244,26 @@ def detect_widget() -> FunctionGui:
|
|
|
244
244
|
detection_options,
|
|
245
245
|
skip_detection: bool,
|
|
246
246
|
soma_diameter: float,
|
|
247
|
+
log_sigma_size: float,
|
|
248
|
+
n_sds_above_mean_thresh: float,
|
|
247
249
|
ball_xy_size: float,
|
|
248
250
|
ball_z_size: float,
|
|
249
251
|
ball_overlap_fraction: float,
|
|
250
|
-
|
|
251
|
-
n_sds_above_mean_thresh: int,
|
|
252
|
+
detection_batch_size: int,
|
|
252
253
|
soma_spread_factor: float,
|
|
253
|
-
max_cluster_size:
|
|
254
|
+
max_cluster_size: float,
|
|
254
255
|
classification_options,
|
|
255
256
|
skip_classification: bool,
|
|
256
257
|
use_pre_trained_weights: bool,
|
|
257
258
|
trained_model: Optional[Path],
|
|
258
|
-
|
|
259
|
+
classification_batch_size: int,
|
|
259
260
|
misc_options,
|
|
260
261
|
start_plane: int,
|
|
261
262
|
end_plane: int,
|
|
262
263
|
n_free_cpus: int,
|
|
263
264
|
analyse_local: bool,
|
|
265
|
+
use_gpu: bool,
|
|
266
|
+
pin_memory: bool,
|
|
264
267
|
debug: bool,
|
|
265
268
|
reset_button,
|
|
266
269
|
) -> None:
|
|
@@ -270,43 +273,60 @@ def detect_widget() -> FunctionGui:
|
|
|
270
273
|
Parameters
|
|
271
274
|
----------
|
|
272
275
|
voxel_size_z : float
|
|
273
|
-
Size of your voxels in the axial dimension
|
|
276
|
+
Size of your voxels in the axial dimension (microns)
|
|
274
277
|
voxel_size_y : float
|
|
275
|
-
Size of your voxels in the y direction (top to bottom)
|
|
278
|
+
Size of your voxels in the y direction (top to bottom) (microns)
|
|
276
279
|
voxel_size_x : float
|
|
277
|
-
Size of your voxels in the x direction (left to right)
|
|
280
|
+
Size of your voxels in the x direction (left to right) (microns)
|
|
278
281
|
skip_detection : bool
|
|
279
282
|
If selected, the detection step is skipped and instead we get the
|
|
280
283
|
detected cells from the cell layer below (from a previous
|
|
281
284
|
detection run or import)
|
|
282
285
|
soma_diameter : float
|
|
283
|
-
The expected in-plane soma diameter (microns)
|
|
286
|
+
The expected in-plane (xy) soma diameter (microns)
|
|
287
|
+
log_sigma_size : float
|
|
288
|
+
Gaussian filter width (as a fraction of soma diameter) used during
|
|
289
|
+
2d in-plane Laplacian of Gaussian filtering
|
|
290
|
+
n_sds_above_mean_thresh : float
|
|
291
|
+
Intensity threshold (the number of standard deviations above
|
|
292
|
+
the mean) of the filtered 2d planes used to mark pixels as
|
|
293
|
+
foreground or background
|
|
284
294
|
ball_xy_size : float
|
|
285
|
-
|
|
295
|
+
3d filter's in-plane (xy) filter ball size (microns)
|
|
286
296
|
ball_z_size : float
|
|
287
|
-
|
|
297
|
+
3d filter's axial (z) filter ball size (microns)
|
|
288
298
|
ball_overlap_fraction : float
|
|
289
|
-
|
|
290
|
-
to retain
|
|
291
|
-
|
|
292
|
-
|
|
293
|
-
|
|
294
|
-
|
|
299
|
+
3d filter's fraction of the ball filter needed to be filled by
|
|
300
|
+
foreground voxels, centered on a voxel, to retain the voxel
|
|
301
|
+
detection_batch_size: int
|
|
302
|
+
The number of planes of the original data volume to process at
|
|
303
|
+
once. The GPU/CPU memory must be able to contain this many planes
|
|
304
|
+
for all the filters. For performance-critical applications, tune
|
|
305
|
+
to maximize memory usage without
|
|
306
|
+
running out. Check your GPU/CPU memory to verify it's not full
|
|
295
307
|
soma_spread_factor : float
|
|
296
|
-
Cell spread factor
|
|
297
|
-
|
|
298
|
-
|
|
299
|
-
|
|
300
|
-
|
|
301
|
-
|
|
302
|
-
|
|
303
|
-
|
|
308
|
+
Cell spread factor for determining the largest cell volume before
|
|
309
|
+
splitting up cell clusters. Structures with spherical volume of
|
|
310
|
+
diameter `soma_spread_factor * soma_diameter` or less will not be
|
|
311
|
+
split
|
|
312
|
+
max_cluster_size : float
|
|
313
|
+
Largest detected cell cluster (in cubic um) where splitting
|
|
314
|
+
should be attempted. Clusters above this size will be labeled
|
|
315
|
+
as artifacts
|
|
304
316
|
skip_classification : bool
|
|
305
317
|
If selected, the classification step is skipped and all cells from
|
|
306
318
|
the detection stage are added
|
|
319
|
+
use_pre_trained_weights : bool
|
|
320
|
+
Select to use pre-trained model weights
|
|
307
321
|
trained_model : Optional[Path]
|
|
308
322
|
Trained model file path (home directory (default) -> pretrained
|
|
309
323
|
weights)
|
|
324
|
+
classification_batch_size : int
|
|
325
|
+
How many potential cells to classify at one time. The GPU/CPU
|
|
326
|
+
memory must be able to contain at once this many data cubes for
|
|
327
|
+
the models. For performance-critical applications, tune to
|
|
328
|
+
maximize memory usage without running
|
|
329
|
+
out. Check your GPU/CPU memory to verify it's not full
|
|
310
330
|
start_plane : int
|
|
311
331
|
First plane to process (to process a subset of the data)
|
|
312
332
|
end_plane : int
|
|
@@ -315,6 +335,14 @@ def detect_widget() -> FunctionGui:
|
|
|
315
335
|
How many CPU cores to leave free
|
|
316
336
|
analyse_local : bool
|
|
317
337
|
Only analyse planes around the current position
|
|
338
|
+
use_gpu : bool
|
|
339
|
+
If True, use GPU for processing (if available); otherwise, use CPU.
|
|
340
|
+
pin_memory: bool
|
|
341
|
+
Pins data to be sent to the GPU to the CPU memory. This allows
|
|
342
|
+
faster GPU data speeds, but can only be used if the data used by
|
|
343
|
+
the GPU can stay in the CPU RAM while the GPU uses it. I.e. there's
|
|
344
|
+
enough RAM. Otherwise, if there's a risk of the RAM being paged, it
|
|
345
|
+
shouldn't be used. Defaults to False.
|
|
318
346
|
debug : bool
|
|
319
347
|
Increase logging
|
|
320
348
|
reset_button :
|
|
@@ -370,6 +398,7 @@ def detect_widget() -> FunctionGui:
|
|
|
370
398
|
n_sds_above_mean_thresh,
|
|
371
399
|
soma_spread_factor,
|
|
372
400
|
max_cluster_size,
|
|
401
|
+
detection_batch_size,
|
|
373
402
|
)
|
|
374
403
|
|
|
375
404
|
if use_pre_trained_weights:
|
|
@@ -378,7 +407,7 @@ def detect_widget() -> FunctionGui:
|
|
|
378
407
|
skip_classification,
|
|
379
408
|
use_pre_trained_weights,
|
|
380
409
|
trained_model,
|
|
381
|
-
|
|
410
|
+
classification_batch_size,
|
|
382
411
|
)
|
|
383
412
|
|
|
384
413
|
if analyse_local:
|
|
@@ -389,7 +418,13 @@ def detect_widget() -> FunctionGui:
|
|
|
389
418
|
end_plane = len(signal_image.data)
|
|
390
419
|
|
|
391
420
|
misc_inputs = MiscInputs(
|
|
392
|
-
start_plane,
|
|
421
|
+
start_plane,
|
|
422
|
+
end_plane,
|
|
423
|
+
n_free_cpus,
|
|
424
|
+
analyse_local,
|
|
425
|
+
use_gpu,
|
|
426
|
+
pin_memory,
|
|
427
|
+
debug,
|
|
393
428
|
)
|
|
394
429
|
|
|
395
430
|
worker = Worker(
|
|
@@ -1,8 +1,9 @@
|
|
|
1
|
-
from dataclasses import dataclass
|
|
1
|
+
from dataclasses import dataclass, field
|
|
2
2
|
from pathlib import Path
|
|
3
3
|
from typing import List, Optional
|
|
4
4
|
|
|
5
5
|
import numpy
|
|
6
|
+
import torch
|
|
6
7
|
from brainglobe_utils.cells.cells import Cell
|
|
7
8
|
|
|
8
9
|
from cellfinder.napari.input_container import InputContainer
|
|
@@ -30,7 +31,7 @@ class DataInputs(InputContainer):
|
|
|
30
31
|
self.voxel_size_x,
|
|
31
32
|
)
|
|
32
33
|
# del operator doesn't affect self, because asdict creates a copy of
|
|
33
|
-
#
|
|
34
|
+
# the dict.
|
|
34
35
|
del data_input_dict["voxel_size_z"]
|
|
35
36
|
del data_input_dict["voxel_size_y"]
|
|
36
37
|
del data_input_dict["voxel_size_x"]
|
|
@@ -67,9 +68,10 @@ class DetectionInputs(InputContainer):
|
|
|
67
68
|
ball_z_size: float = 15
|
|
68
69
|
ball_overlap_fraction: float = 0.6
|
|
69
70
|
log_sigma_size: float = 0.2
|
|
70
|
-
n_sds_above_mean_thresh:
|
|
71
|
+
n_sds_above_mean_thresh: float = 10
|
|
71
72
|
soma_spread_factor: float = 1.4
|
|
72
|
-
max_cluster_size:
|
|
73
|
+
max_cluster_size: float = 100000
|
|
74
|
+
detection_batch_size: int = 1
|
|
73
75
|
|
|
74
76
|
def as_core_arguments(self) -> dict:
|
|
75
77
|
return super().as_core_arguments()
|
|
@@ -96,14 +98,17 @@ class DetectionInputs(InputContainer):
|
|
|
96
98
|
"n_sds_above_mean_thresh", custom_label="Threshold"
|
|
97
99
|
),
|
|
98
100
|
soma_spread_factor=cls._custom_widget(
|
|
99
|
-
"soma_spread_factor", custom_label="
|
|
101
|
+
"soma_spread_factor", custom_label="Split cell spread"
|
|
100
102
|
),
|
|
101
103
|
max_cluster_size=cls._custom_widget(
|
|
102
104
|
"max_cluster_size",
|
|
103
|
-
custom_label="
|
|
105
|
+
custom_label="Split max cluster",
|
|
104
106
|
min=0,
|
|
105
107
|
max=10000000,
|
|
106
108
|
),
|
|
109
|
+
detection_batch_size=cls._custom_widget(
|
|
110
|
+
"detection_batch_size", custom_label="Batch size (detection)"
|
|
111
|
+
),
|
|
107
112
|
)
|
|
108
113
|
|
|
109
114
|
|
|
@@ -114,7 +119,7 @@ class ClassificationInputs(InputContainer):
|
|
|
114
119
|
skip_classification: bool = False
|
|
115
120
|
use_pre_trained_weights: bool = True
|
|
116
121
|
trained_model: Optional[Path] = Path.home()
|
|
117
|
-
|
|
122
|
+
classification_batch_size: int = 64
|
|
118
123
|
|
|
119
124
|
def as_core_arguments(self) -> dict:
|
|
120
125
|
args = super().as_core_arguments()
|
|
@@ -132,7 +137,10 @@ class ClassificationInputs(InputContainer):
|
|
|
132
137
|
skip_classification=dict(
|
|
133
138
|
value=cls.defaults()["skip_classification"]
|
|
134
139
|
),
|
|
135
|
-
|
|
140
|
+
classification_batch_size=dict(
|
|
141
|
+
value=cls.defaults()["classification_batch_size"],
|
|
142
|
+
label="Batch size (classification)",
|
|
143
|
+
),
|
|
136
144
|
)
|
|
137
145
|
|
|
138
146
|
|
|
@@ -144,10 +152,14 @@ class MiscInputs(InputContainer):
|
|
|
144
152
|
end_plane: int = 0
|
|
145
153
|
n_free_cpus: int = 2
|
|
146
154
|
analyse_local: bool = False
|
|
155
|
+
use_gpu: bool = field(default_factory=lambda: torch.cuda.is_available())
|
|
156
|
+
pin_memory: bool = False
|
|
147
157
|
debug: bool = False
|
|
148
158
|
|
|
149
159
|
def as_core_arguments(self) -> dict:
|
|
150
160
|
misc_input_dict = super().as_core_arguments()
|
|
161
|
+
misc_input_dict["torch_device"] = "cuda" if self.use_gpu else "cpu"
|
|
162
|
+
del misc_input_dict["use_gpu"]
|
|
151
163
|
del misc_input_dict["debug"]
|
|
152
164
|
del misc_input_dict["analyse_local"]
|
|
153
165
|
return misc_input_dict
|
|
@@ -162,5 +174,16 @@ class MiscInputs(InputContainer):
|
|
|
162
174
|
"n_free_cpus", custom_label="Number of free CPUs"
|
|
163
175
|
),
|
|
164
176
|
analyse_local=dict(value=cls.defaults()["analyse_local"]),
|
|
177
|
+
use_gpu=dict(
|
|
178
|
+
widget_type="CheckBox",
|
|
179
|
+
label="Use GPU",
|
|
180
|
+
value=cls.defaults()["use_gpu"],
|
|
181
|
+
enabled=torch.cuda.is_available(),
|
|
182
|
+
),
|
|
183
|
+
pin_memory=dict(
|
|
184
|
+
widget_type="CheckBox",
|
|
185
|
+
label="Pin data to memory",
|
|
186
|
+
value=cls.defaults()["pin_memory"],
|
|
187
|
+
),
|
|
165
188
|
debug=dict(value=cls.defaults()["debug"]),
|
|
166
189
|
)
|
|
@@ -1,8 +1,18 @@
|
|
|
1
1
|
from abc import abstractmethod
|
|
2
|
-
from dataclasses import
|
|
2
|
+
from dataclasses import dataclass, fields
|
|
3
3
|
from typing import Optional
|
|
4
4
|
|
|
5
5
|
|
|
6
|
+
def asdict_no_copy(obj: dataclass) -> dict:
|
|
7
|
+
"""
|
|
8
|
+
Similar to `asdict`, except it makes no copies of the field values.
|
|
9
|
+
asdict will do a deep copy of field values that are non-basic objects.
|
|
10
|
+
|
|
11
|
+
It still creates a new dict to return, though.
|
|
12
|
+
"""
|
|
13
|
+
return {field.name: getattr(obj, field.name) for field in fields(obj)}
|
|
14
|
+
|
|
15
|
+
|
|
6
16
|
@dataclass
|
|
7
17
|
class InputContainer:
|
|
8
18
|
"""Base for classes that contain inputs
|
|
@@ -23,7 +33,7 @@ class InputContainer:
|
|
|
23
33
|
# Derived classes are not expected to be particularly
|
|
24
34
|
# slow to instantiate, so use the default constructor
|
|
25
35
|
# to avoid code repetition.
|
|
26
|
-
return
|
|
36
|
+
return asdict_no_copy(cls())
|
|
27
37
|
|
|
28
38
|
@abstractmethod
|
|
29
39
|
def as_core_arguments(self) -> dict:
|
|
@@ -32,10 +42,10 @@ class InputContainer:
|
|
|
32
42
|
The implementation provided here can be re-used in derived classes, if
|
|
33
43
|
convenient.
|
|
34
44
|
"""
|
|
35
|
-
# note that
|
|
45
|
+
# note that asdict_no_copy returns a new instance of a dict,
|
|
36
46
|
# so any subsequent modifications of this dict won't affect the class
|
|
37
47
|
# instance
|
|
38
|
-
return
|
|
48
|
+
return asdict_no_copy(self)
|
|
39
49
|
|
|
40
50
|
@classmethod
|
|
41
51
|
def _custom_widget(
|
cellfinder/napari/train/train.py
CHANGED
|
@@ -25,14 +25,14 @@ def run_training(
|
|
|
25
25
|
optional_training_inputs: OptionalTrainingInputs,
|
|
26
26
|
misc_training_inputs: MiscTrainingInputs,
|
|
27
27
|
):
|
|
28
|
-
|
|
28
|
+
show_info("Running training...")
|
|
29
29
|
train_yaml(
|
|
30
30
|
**training_data_inputs.as_core_arguments(),
|
|
31
31
|
**optional_network_inputs.as_core_arguments(),
|
|
32
32
|
**optional_training_inputs.as_core_arguments(),
|
|
33
33
|
**misc_training_inputs.as_core_arguments(),
|
|
34
34
|
)
|
|
35
|
-
|
|
35
|
+
show_info("Training finished!")
|
|
36
36
|
|
|
37
37
|
|
|
38
38
|
def training_widget() -> FunctionGui:
|
|
@@ -60,7 +60,6 @@ def training_widget() -> FunctionGui:
|
|
|
60
60
|
continue_training: bool,
|
|
61
61
|
augment: bool,
|
|
62
62
|
tensorboard: bool,
|
|
63
|
-
save_weights: bool,
|
|
64
63
|
save_checkpoints: bool,
|
|
65
64
|
save_progress: bool,
|
|
66
65
|
epochs: int,
|
|
@@ -96,9 +95,6 @@ def training_widget() -> FunctionGui:
|
|
|
96
95
|
Augment the training data to improve generalisation
|
|
97
96
|
tensorboard : bool
|
|
98
97
|
Log to output_directory/tensorboard
|
|
99
|
-
save_weights : bool
|
|
100
|
-
Only store the model weights, and not the full model
|
|
101
|
-
Useful to save storage space
|
|
102
98
|
save_checkpoints : bool
|
|
103
99
|
Store the model at intermediate points during training
|
|
104
100
|
save_progress : bool
|
|
@@ -133,7 +129,6 @@ def training_widget() -> FunctionGui:
|
|
|
133
129
|
continue_training,
|
|
134
130
|
augment,
|
|
135
131
|
tensorboard,
|
|
136
|
-
save_weights,
|
|
137
132
|
save_checkpoints,
|
|
138
133
|
save_progress,
|
|
139
134
|
epochs,
|
|
@@ -147,6 +142,7 @@ def training_widget() -> FunctionGui:
|
|
|
147
142
|
if yaml_files[0] == Path.home(): # type: ignore
|
|
148
143
|
show_info("Please select a YAML file for training")
|
|
149
144
|
else:
|
|
145
|
+
show_info("Starting training process...")
|
|
150
146
|
worker = run_training(
|
|
151
147
|
training_data_inputs,
|
|
152
148
|
optional_network_inputs,
|
|
@@ -75,7 +75,6 @@ class OptionalTrainingInputs(InputContainer):
|
|
|
75
75
|
continue_training: bool = False
|
|
76
76
|
augment: bool = True
|
|
77
77
|
tensorboard: bool = False
|
|
78
|
-
save_weights: bool = False
|
|
79
78
|
save_checkpoints: bool = True
|
|
80
79
|
save_progress: bool = True
|
|
81
80
|
epochs: int = 100
|
|
@@ -98,7 +97,6 @@ class OptionalTrainingInputs(InputContainer):
|
|
|
98
97
|
continue_training=cls._custom_widget("continue_training"),
|
|
99
98
|
augment=cls._custom_widget("augment"),
|
|
100
99
|
tensorboard=cls._custom_widget("tensorboard"),
|
|
101
|
-
save_weights=cls._custom_widget("save_weights"),
|
|
102
100
|
save_checkpoints=cls._custom_widget("save_checkpoints"),
|
|
103
101
|
save_progress=cls._custom_widget("save_progress"),
|
|
104
102
|
epochs=cls._custom_widget("epochs"),
|