cellfinder 1.4.1__py3-none-any.whl → 1.9.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- cellfinder/cli_migration_warning.py +3 -1
- cellfinder/core/classify/classify.py +51 -5
- cellfinder/core/classify/tools.py +13 -3
- cellfinder/core/detect/detect.py +94 -59
- cellfinder/core/detect/filters/plane/plane_filter.py +107 -10
- cellfinder/core/detect/filters/setup_filters.py +51 -12
- cellfinder/core/detect/filters/volume/ball_filter.py +5 -5
- cellfinder/core/detect/filters/volume/structure_detection.py +5 -0
- cellfinder/core/detect/filters/volume/structure_splitting.py +3 -2
- cellfinder/core/detect/filters/volume/volume_filter.py +1 -1
- cellfinder/core/download/download.py +2 -1
- cellfinder/core/main.py +162 -30
- cellfinder/core/tools/threading.py +4 -3
- cellfinder/core/tools/tools.py +1 -1
- cellfinder/core/train/{train_yml.py → train_yaml.py} +6 -15
- cellfinder/napari/curation.py +72 -21
- cellfinder/napari/detect/detect.py +87 -28
- cellfinder/napari/detect/detect_containers.py +41 -9
- cellfinder/napari/detect/thread_worker.py +26 -16
- cellfinder/napari/input_container.py +14 -4
- cellfinder/napari/train/train.py +5 -9
- cellfinder/napari/train/train_containers.py +2 -4
- cellfinder/napari/utils.py +6 -1
- {cellfinder-1.4.1.dist-info → cellfinder-1.9.0.dist-info}/METADATA +16 -12
- {cellfinder-1.4.1.dist-info → cellfinder-1.9.0.dist-info}/RECORD +29 -29
- {cellfinder-1.4.1.dist-info → cellfinder-1.9.0.dist-info}/WHEEL +1 -1
- {cellfinder-1.4.1.dist-info → cellfinder-1.9.0.dist-info}/entry_points.txt +1 -1
- {cellfinder-1.4.1.dist-info → cellfinder-1.9.0.dist-info/licenses}/LICENSE +0 -0
- {cellfinder-1.4.1.dist-info → cellfinder-1.9.0.dist-info}/top_level.txt +0 -0
|
@@ -1,5 +1,7 @@
|
|
|
1
1
|
import argparse
|
|
2
2
|
|
|
3
|
+
from cellfinder.core import logger
|
|
4
|
+
|
|
3
5
|
BRAINGLOBE_WORKFLOWS = "https://github.com/brainglobe/brainglobe-workflows"
|
|
4
6
|
NEW_NAME = "brainmapper"
|
|
5
7
|
BLOG_POST = "https://brainglobe.info/blog/version1/core_and_napari_merge.html"
|
|
@@ -36,7 +38,7 @@ def cli_catch() -> None:
|
|
|
36
38
|
),
|
|
37
39
|
)
|
|
38
40
|
|
|
39
|
-
|
|
41
|
+
logger.warning(
|
|
40
42
|
"Hey, it looks like you're trying to run the old command-line tool.",
|
|
41
43
|
"This workflow has been renamed and moved -",
|
|
42
44
|
" you can now find it in the brainglobe-workflows package:\n",
|
|
@@ -11,7 +11,7 @@ from brainglobe_utils.general.system import get_num_processes
|
|
|
11
11
|
from cellfinder.core import logger, types
|
|
12
12
|
from cellfinder.core.classify.cube_generator import CubeGeneratorFromFile
|
|
13
13
|
from cellfinder.core.classify.tools import get_model
|
|
14
|
-
from cellfinder.core.train.
|
|
14
|
+
from cellfinder.core.train.train_yaml import depth_type, models
|
|
15
15
|
|
|
16
16
|
|
|
17
17
|
def main(
|
|
@@ -19,8 +19,8 @@ def main(
|
|
|
19
19
|
signal_array: types.array,
|
|
20
20
|
background_array: types.array,
|
|
21
21
|
n_free_cpus: int,
|
|
22
|
-
voxel_sizes: Tuple[
|
|
23
|
-
network_voxel_sizes: Tuple[
|
|
22
|
+
voxel_sizes: Tuple[float, float, float],
|
|
23
|
+
network_voxel_sizes: Tuple[float, float, float],
|
|
24
24
|
batch_size: int,
|
|
25
25
|
cube_height: int,
|
|
26
26
|
cube_width: int,
|
|
@@ -29,12 +29,58 @@ def main(
|
|
|
29
29
|
model_weights: Optional[os.PathLike],
|
|
30
30
|
network_depth: depth_type,
|
|
31
31
|
max_workers: int = 3,
|
|
32
|
+
pin_memory: bool = False,
|
|
32
33
|
*,
|
|
33
34
|
callback: Optional[Callable[[int], None]] = None,
|
|
34
35
|
) -> List[Cell]:
|
|
35
36
|
"""
|
|
36
37
|
Parameters
|
|
37
38
|
----------
|
|
39
|
+
|
|
40
|
+
points: List of Cell objects
|
|
41
|
+
The potential cells to classify.
|
|
42
|
+
signal_array : numpy.ndarray or dask array
|
|
43
|
+
3D array representing the signal data in z, y, x order.
|
|
44
|
+
background_array : numpy.ndarray or dask array
|
|
45
|
+
3D array representing the background data in z, y, x order.
|
|
46
|
+
n_free_cpus : int
|
|
47
|
+
How many CPU cores to leave free.
|
|
48
|
+
voxel_sizes : 3-tuple of floats
|
|
49
|
+
Size of your voxels in the z, y, and x dimensions.
|
|
50
|
+
network_voxel_sizes : 3-tuple of floats
|
|
51
|
+
Size of the pre-trained network's voxels in the z, y, and x dimensions.
|
|
52
|
+
batch_size : int
|
|
53
|
+
How many potential cells to classify at one time. The GPU/CPU
|
|
54
|
+
memory must be able to contain at once this many data cubes for
|
|
55
|
+
the models. For performance-critical applications, tune to maximize
|
|
56
|
+
memory usage without running out. Check your GPU/CPU memory to verify
|
|
57
|
+
it's not full.
|
|
58
|
+
cube_height: int
|
|
59
|
+
The height of the data cube centered on the cell used for
|
|
60
|
+
classification. Defaults to `50`.
|
|
61
|
+
cube_width: int
|
|
62
|
+
The width of the data cube centered on the cell used for
|
|
63
|
+
classification. Defaults to `50`.
|
|
64
|
+
cube_depth: int
|
|
65
|
+
The depth of the data cube centered on the cell used for
|
|
66
|
+
classification. Defaults to `20`.
|
|
67
|
+
trained_model : Optional[Path]
|
|
68
|
+
Trained model file path (home directory (default) -> pretrained
|
|
69
|
+
weights).
|
|
70
|
+
model_weights : Optional[Path]
|
|
71
|
+
Model weights path (home directory (default) -> pretrained
|
|
72
|
+
weights).
|
|
73
|
+
network_depth: str
|
|
74
|
+
The network depth to use during classification. Defaults to `"50"`.
|
|
75
|
+
max_workers: int
|
|
76
|
+
The number of sub-processes to use for data loading / processing.
|
|
77
|
+
Defaults to 3.
|
|
78
|
+
pin_memory: bool
|
|
79
|
+
Pins data to be sent to the GPU to the CPU memory. This allows faster
|
|
80
|
+
GPU data speeds, but can only be used if the data used by the GPU can
|
|
81
|
+
stay in the CPU RAM while the GPU uses it. I.e. there's enough RAM.
|
|
82
|
+
Otherwise, if there's a risk of the RAM being paged, it shouldn't be
|
|
83
|
+
used. Defaults to False.
|
|
38
84
|
callback : Callable[int], optional
|
|
39
85
|
A callback function that is called during classification. Called with
|
|
40
86
|
the batch number once that batch has been classified.
|
|
@@ -70,7 +116,7 @@ def main(
|
|
|
70
116
|
)
|
|
71
117
|
|
|
72
118
|
if trained_model and Path(trained_model).suffix == ".h5":
|
|
73
|
-
|
|
119
|
+
logger.warning(
|
|
74
120
|
"Weights provided in place of the model, "
|
|
75
121
|
"loading weights into default model."
|
|
76
122
|
)
|
|
@@ -103,7 +149,7 @@ def main(
|
|
|
103
149
|
points_list.append(cell)
|
|
104
150
|
|
|
105
151
|
time_elapsed = datetime.now() - start_time
|
|
106
|
-
|
|
152
|
+
logger.info(
|
|
107
153
|
"Classfication complete - all points done in : {}".format(time_elapsed)
|
|
108
154
|
)
|
|
109
155
|
|
|
@@ -47,9 +47,19 @@ def get_model(
|
|
|
47
47
|
f"Setting model weights according to: {model_weights}",
|
|
48
48
|
)
|
|
49
49
|
if model_weights is None:
|
|
50
|
-
raise OSError(
|
|
51
|
-
|
|
52
|
-
|
|
50
|
+
raise OSError(
|
|
51
|
+
"`model_weights` must be provided for inference "
|
|
52
|
+
"or continued training."
|
|
53
|
+
)
|
|
54
|
+
try:
|
|
55
|
+
model.load_weights(model_weights)
|
|
56
|
+
except (OSError, ValueError) as e:
|
|
57
|
+
raise ValueError(
|
|
58
|
+
f"Error loading weights: {model_weights}.\n"
|
|
59
|
+
"Provided weights don't match the model architecture.\n"
|
|
60
|
+
) from e
|
|
61
|
+
|
|
62
|
+
return model
|
|
53
63
|
|
|
54
64
|
|
|
55
65
|
def make_lists(
|
cellfinder/core/detect/detect.py
CHANGED
|
@@ -48,12 +48,14 @@ def main(
|
|
|
48
48
|
save_planes: bool = False,
|
|
49
49
|
plane_directory: Optional[str] = None,
|
|
50
50
|
batch_size: Optional[int] = None,
|
|
51
|
-
torch_device: str =
|
|
52
|
-
|
|
53
|
-
split_ball_xy_size:
|
|
54
|
-
split_ball_z_size:
|
|
51
|
+
torch_device: Optional[str] = None,
|
|
52
|
+
pin_memory: bool = False,
|
|
53
|
+
split_ball_xy_size: float = 6,
|
|
54
|
+
split_ball_z_size: float = 15,
|
|
55
55
|
split_ball_overlap_fraction: float = 0.8,
|
|
56
|
-
|
|
56
|
+
n_splitting_iter: int = 10,
|
|
57
|
+
n_sds_above_mean_tiled_thresh: float = 10,
|
|
58
|
+
tiled_thresh_tile_size: float | None = None,
|
|
57
59
|
*,
|
|
58
60
|
callback: Optional[Callable[[int], None]] = None,
|
|
59
61
|
) -> List[Cell]:
|
|
@@ -62,69 +64,94 @@ def main(
|
|
|
62
64
|
|
|
63
65
|
Parameters
|
|
64
66
|
----------
|
|
65
|
-
signal_array : numpy.ndarray
|
|
66
|
-
3D array representing the signal data.
|
|
67
|
-
|
|
67
|
+
signal_array : numpy.ndarray or dask array
|
|
68
|
+
3D array representing the signal data in z, y, x order.
|
|
68
69
|
start_plane : int
|
|
69
|
-
|
|
70
|
-
|
|
70
|
+
First plane index to process (inclusive, to process a subset of the
|
|
71
|
+
data).
|
|
71
72
|
end_plane : int
|
|
72
|
-
|
|
73
|
-
|
|
74
|
-
voxel_sizes :
|
|
75
|
-
|
|
76
|
-
|
|
73
|
+
Last plane index to process (exclusive, to process a subset of the
|
|
74
|
+
data).
|
|
75
|
+
voxel_sizes : 3-tuple of floats
|
|
76
|
+
Size of your voxels in the z, y, and x dimensions (microns).
|
|
77
77
|
soma_diameter : float
|
|
78
|
-
|
|
79
|
-
|
|
78
|
+
The expected in-plane (xy) soma diameter (microns).
|
|
80
79
|
max_cluster_size : float
|
|
81
|
-
|
|
82
|
-
|
|
80
|
+
Largest detected cell cluster (in cubic um) where splitting
|
|
81
|
+
should be attempted. Clusters above this size will be labeled
|
|
82
|
+
as artifacts.
|
|
83
83
|
ball_xy_size : float
|
|
84
|
-
|
|
85
|
-
|
|
84
|
+
3d filter's in-plane (xy) filter ball size (microns).
|
|
86
85
|
ball_z_size : float
|
|
87
|
-
|
|
88
|
-
|
|
86
|
+
3d filter's axial (z) filter ball size (microns).
|
|
89
87
|
ball_overlap_fraction : float
|
|
90
|
-
|
|
91
|
-
|
|
88
|
+
3d filter's fraction of the ball filter needed to be filled by
|
|
89
|
+
foreground voxels, centered on a voxel, to retain the voxel.
|
|
92
90
|
soma_spread_factor : float
|
|
93
|
-
|
|
94
|
-
|
|
91
|
+
Cell spread factor for determining the largest cell volume before
|
|
92
|
+
splitting up cell clusters. Structures with spherical volume of
|
|
93
|
+
diameter `soma_spread_factor * soma_diameter` or less will not be
|
|
94
|
+
split.
|
|
95
95
|
n_free_cpus : int
|
|
96
|
-
|
|
97
|
-
|
|
96
|
+
How many CPU cores to leave free.
|
|
98
97
|
log_sigma_size : float
|
|
99
|
-
|
|
100
|
-
|
|
98
|
+
Gaussian filter width (as a fraction of soma diameter) used during
|
|
99
|
+
2d in-plane Laplacian of Gaussian filtering.
|
|
101
100
|
n_sds_above_mean_thresh : float
|
|
102
|
-
|
|
103
|
-
|
|
101
|
+
Per-plane intensity threshold (the number of standard deviations
|
|
102
|
+
above the mean) of the filtered 2d planes used to mark pixels as
|
|
103
|
+
foreground or background.
|
|
104
104
|
outlier_keep : bool, optional
|
|
105
105
|
Whether to keep outliers during detection. Defaults to False.
|
|
106
|
-
|
|
107
106
|
artifact_keep : bool, optional
|
|
108
107
|
Whether to keep artifacts during detection. Defaults to False.
|
|
109
|
-
|
|
110
108
|
save_planes : bool, optional
|
|
111
109
|
Whether to save the planes during detection. Defaults to False.
|
|
112
|
-
|
|
113
110
|
plane_directory : str, optional
|
|
114
111
|
Directory path to save the planes. Defaults to None.
|
|
115
|
-
|
|
116
|
-
|
|
117
|
-
The
|
|
118
|
-
|
|
119
|
-
|
|
120
|
-
|
|
121
|
-
becomes slower.
|
|
122
|
-
|
|
112
|
+
batch_size: int
|
|
113
|
+
The number of planes of the original data volume to process at
|
|
114
|
+
once. The GPU/CPU memory must be able to contain this many planes
|
|
115
|
+
for all the filters. For performance-critical applications, tune to
|
|
116
|
+
maximize memory usage without running out. Check your GPU/CPU memory
|
|
117
|
+
to verify it's not full.
|
|
123
118
|
torch_device : str, optional
|
|
124
|
-
The device on which to run the computation.
|
|
125
|
-
|
|
126
|
-
|
|
127
|
-
|
|
119
|
+
The device on which to run the computation. If not specified (None),
|
|
120
|
+
"cuda" will be used if a GPU is available, otherwise "cpu".
|
|
121
|
+
You can also manually specify "cuda" or "cpu".
|
|
122
|
+
pin_memory: bool
|
|
123
|
+
Pins data to be sent to the GPU to the CPU memory. This allows faster
|
|
124
|
+
GPU data speeds, but can only be used if the data used by the GPU can
|
|
125
|
+
stay in the CPU RAM while the GPU uses it. I.e. there's enough RAM.
|
|
126
|
+
Otherwise, if there's a risk of the RAM being paged, it shouldn't be
|
|
127
|
+
used. Defaults to False.
|
|
128
|
+
split_ball_xy_size: float
|
|
129
|
+
Similar to `ball_xy_size`, except the value to use for the 3d
|
|
130
|
+
filter during cluster splitting.
|
|
131
|
+
split_ball_z_size: float
|
|
132
|
+
Similar to `ball_z_size`, except the value to use for the 3d filter
|
|
133
|
+
during cluster splitting.
|
|
134
|
+
split_ball_overlap_fraction: float
|
|
135
|
+
Similar to `ball_overlap_fraction`, except the value to use for the
|
|
136
|
+
3d filter during cluster splitting.
|
|
137
|
+
n_splitting_iter: int
|
|
138
|
+
The number of iterations to run the 3d filtering on a cluster. Each
|
|
139
|
+
iteration reduces the cluster size by the voxels not retained in
|
|
140
|
+
the previous iteration.
|
|
141
|
+
n_sds_above_mean_tiled_thresh : float
|
|
142
|
+
Per-plane, per-tile intensity threshold (the number of standard
|
|
143
|
+
deviations above the mean) for the filtered 2d planes used to mark
|
|
144
|
+
pixels as foreground or background. When used (tile size is not zero),
|
|
145
|
+
a pixel is marked as foreground if its intensity is above both the
|
|
146
|
+
per-plane and per-tile threshold. I.e. it's above the set number of
|
|
147
|
+
standard deviations of the per-plane average and of the per-plane
|
|
148
|
+
per-tile average for the tile that contains it.
|
|
149
|
+
tiled_thresh_tile_size : float
|
|
150
|
+
The tile size used to tile the x, y plane to calculate the local
|
|
151
|
+
average intensity for the tiled threshold. The value is multiplied
|
|
152
|
+
by soma diameter (i.e. 1 means one soma diameter). If zero or None, the
|
|
153
|
+
tiled threshold is disabled and only the per-plane threshold is used.
|
|
154
|
+
Tiling is done with 50% overlap when striding.
|
|
128
155
|
callback : Callable[int], optional
|
|
129
156
|
A callback function that is called every time a plane has finished
|
|
130
157
|
being processed. Called with the plane number that has finished.
|
|
@@ -132,9 +159,11 @@ def main(
|
|
|
132
159
|
Returns
|
|
133
160
|
-------
|
|
134
161
|
List[Cell]
|
|
135
|
-
List of detected
|
|
162
|
+
List of detected cell candidates.
|
|
136
163
|
"""
|
|
137
164
|
start_time = datetime.now()
|
|
165
|
+
if torch_device is None:
|
|
166
|
+
torch_device = "cuda" if torch.cuda.is_available() else "cpu"
|
|
138
167
|
if batch_size is None:
|
|
139
168
|
if torch_device == "cpu":
|
|
140
169
|
batch_size = 4
|
|
@@ -155,6 +184,12 @@ def main(
|
|
|
155
184
|
end_plane = min(len(signal_array), end_plane)
|
|
156
185
|
|
|
157
186
|
torch_device = torch_device.lower()
|
|
187
|
+
# Use SciPy filtering on CPU (better performance); use PyTorch on GPU
|
|
188
|
+
if torch_device != "cuda":
|
|
189
|
+
use_scipy = True
|
|
190
|
+
else:
|
|
191
|
+
use_scipy = False
|
|
192
|
+
|
|
158
193
|
batch_size = max(batch_size, 1)
|
|
159
194
|
# brainmapper can pass them in as str
|
|
160
195
|
voxel_sizes = list(map(float, voxel_sizes))
|
|
@@ -174,25 +209,24 @@ def main(
|
|
|
174
209
|
ball_overlap_fraction=ball_overlap_fraction,
|
|
175
210
|
log_sigma_size=log_sigma_size,
|
|
176
211
|
n_sds_above_mean_thresh=n_sds_above_mean_thresh,
|
|
212
|
+
n_sds_above_mean_tiled_thresh=n_sds_above_mean_tiled_thresh,
|
|
213
|
+
tiled_thresh_tile_size=tiled_thresh_tile_size,
|
|
177
214
|
outlier_keep=outlier_keep,
|
|
178
215
|
artifact_keep=artifact_keep,
|
|
179
216
|
save_planes=save_planes,
|
|
180
217
|
plane_directory=plane_directory,
|
|
181
218
|
batch_size=batch_size,
|
|
182
219
|
torch_device=torch_device,
|
|
220
|
+
pin_memory=pin_memory,
|
|
221
|
+
n_splitting_iter=n_splitting_iter,
|
|
183
222
|
)
|
|
184
223
|
|
|
185
224
|
# replicate the settings specific to splitting, before we access anything
|
|
186
225
|
# of the original settings, causing cached properties
|
|
187
226
|
kwargs = dataclasses.asdict(settings)
|
|
188
|
-
kwargs["ball_z_size_um"] = split_ball_z_size
|
|
189
|
-
kwargs["ball_xy_size_um"] =
|
|
190
|
-
split_ball_xy_size * settings.in_plane_pixel_size
|
|
191
|
-
)
|
|
227
|
+
kwargs["ball_z_size_um"] = split_ball_z_size
|
|
228
|
+
kwargs["ball_xy_size_um"] = split_ball_xy_size
|
|
192
229
|
kwargs["ball_overlap_fraction"] = split_ball_overlap_fraction
|
|
193
|
-
kwargs["soma_diameter_um"] = (
|
|
194
|
-
split_soma_diameter * settings.in_plane_pixel_size
|
|
195
|
-
)
|
|
196
230
|
# always run on cpu because copying to gpu overhead is likely slower than
|
|
197
231
|
# any benefit for detection on smallish volumes
|
|
198
232
|
kwargs["torch_device"] = "cpu"
|
|
@@ -212,7 +246,9 @@ def main(
|
|
|
212
246
|
plane_shape=settings.plane_shape,
|
|
213
247
|
clipping_value=settings.clipping_value,
|
|
214
248
|
threshold_value=settings.threshold_value,
|
|
215
|
-
n_sds_above_mean_thresh=n_sds_above_mean_thresh,
|
|
249
|
+
n_sds_above_mean_thresh=settings.n_sds_above_mean_thresh,
|
|
250
|
+
n_sds_above_mean_tiled_thresh=settings.n_sds_above_mean_tiled_thresh,
|
|
251
|
+
tiled_thresh_tile_size=settings.tiled_thresh_tile_size,
|
|
216
252
|
log_sigma_size=log_sigma_size,
|
|
217
253
|
soma_diameter=settings.soma_diameter,
|
|
218
254
|
torch_device=torch_device,
|
|
@@ -231,6 +267,5 @@ def main(
|
|
|
231
267
|
|
|
232
268
|
time_elapsed = datetime.now() - start_time
|
|
233
269
|
s = f"Detection complete. Found {len(cells)} cells in {time_elapsed}"
|
|
234
|
-
logger.
|
|
235
|
-
print(s)
|
|
270
|
+
logger.info(s)
|
|
236
271
|
return cells
|
|
@@ -1,13 +1,12 @@
|
|
|
1
|
-
from dataclasses import dataclass, field
|
|
2
1
|
from typing import Tuple
|
|
3
2
|
|
|
4
3
|
import torch
|
|
4
|
+
import torch.nn.functional as F
|
|
5
5
|
|
|
6
6
|
from cellfinder.core.detect.filters.plane.classical_filter import PeakEnhancer
|
|
7
7
|
from cellfinder.core.detect.filters.plane.tile_walker import TileWalker
|
|
8
8
|
|
|
9
9
|
|
|
10
|
-
@dataclass
|
|
11
10
|
class TileProcessor:
|
|
12
11
|
"""
|
|
13
12
|
Processor that filters each plane to highlight the peaks and also
|
|
@@ -39,7 +38,7 @@ class TileProcessor:
|
|
|
39
38
|
Number of standard deviations above the mean threshold to use for
|
|
40
39
|
determining whether a voxel is bright.
|
|
41
40
|
log_sigma_size : float
|
|
42
|
-
Size of the sigma for the
|
|
41
|
+
Size of the Gaussian sigma for the Laplacian of Gaussian filtering.
|
|
43
42
|
soma_diameter : float
|
|
44
43
|
Diameter of the soma in voxels.
|
|
45
44
|
torch_device: str
|
|
@@ -63,12 +62,22 @@ class TileProcessor:
|
|
|
63
62
|
# voxels who are this many std above mean or more are set to
|
|
64
63
|
# threshold_value
|
|
65
64
|
n_sds_above_mean_thresh: float
|
|
65
|
+
# If used, voxels who are this many or more std above mean of the
|
|
66
|
+
# containing tile as well as above n_sds_above_mean_thresh for the plane
|
|
67
|
+
# average are set to threshold_value.
|
|
68
|
+
n_sds_above_mean_tiled_thresh: float
|
|
69
|
+
# the tile size, in pixels, that will be used to tile the x, y plane when
|
|
70
|
+
# we calculate the per-tile mean / std for use with
|
|
71
|
+
# n_sds_above_mean_tiled_thresh. We use 50% overlap when tiling.
|
|
72
|
+
local_threshold_tile_size_px: int = 0
|
|
73
|
+
# the torch device name
|
|
74
|
+
torch_device: str = ""
|
|
66
75
|
|
|
67
76
|
# filter that finds the peaks in the planes
|
|
68
|
-
peak_enhancer: PeakEnhancer =
|
|
77
|
+
peak_enhancer: PeakEnhancer = None
|
|
69
78
|
# generates tiles of the planes, with each tile marked as being inside
|
|
70
79
|
# or outside the brain based on brightness
|
|
71
|
-
tile_walker: TileWalker =
|
|
80
|
+
tile_walker: TileWalker = None
|
|
72
81
|
|
|
73
82
|
def __init__(
|
|
74
83
|
self,
|
|
@@ -76,6 +85,8 @@ class TileProcessor:
|
|
|
76
85
|
clipping_value: int,
|
|
77
86
|
threshold_value: int,
|
|
78
87
|
n_sds_above_mean_thresh: float,
|
|
88
|
+
n_sds_above_mean_tiled_thresh: float,
|
|
89
|
+
tiled_thresh_tile_size: float | None,
|
|
79
90
|
log_sigma_size: float,
|
|
80
91
|
soma_diameter: int,
|
|
81
92
|
torch_device: str,
|
|
@@ -85,6 +96,12 @@ class TileProcessor:
|
|
|
85
96
|
self.clipping_value = clipping_value
|
|
86
97
|
self.threshold_value = threshold_value
|
|
87
98
|
self.n_sds_above_mean_thresh = n_sds_above_mean_thresh
|
|
99
|
+
self.n_sds_above_mean_tiled_thresh = n_sds_above_mean_tiled_thresh
|
|
100
|
+
if tiled_thresh_tile_size:
|
|
101
|
+
self.local_threshold_tile_size_px = int(
|
|
102
|
+
round(soma_diameter * tiled_thresh_tile_size)
|
|
103
|
+
)
|
|
104
|
+
self.torch_device = torch_device
|
|
88
105
|
|
|
89
106
|
laplace_gaussian_sigma = log_sigma_size * soma_diameter
|
|
90
107
|
self.peak_enhancer = PeakEnhancer(
|
|
@@ -131,7 +148,10 @@ class TileProcessor:
|
|
|
131
148
|
planes,
|
|
132
149
|
enhanced_planes,
|
|
133
150
|
self.n_sds_above_mean_thresh,
|
|
151
|
+
self.n_sds_above_mean_tiled_thresh,
|
|
152
|
+
self.local_threshold_tile_size_px,
|
|
134
153
|
self.threshold_value,
|
|
154
|
+
self.torch_device,
|
|
135
155
|
)
|
|
136
156
|
|
|
137
157
|
return planes, inside_brain_tiles
|
|
@@ -145,21 +165,98 @@ def _threshold_planes(
|
|
|
145
165
|
planes: torch.Tensor,
|
|
146
166
|
enhanced_planes: torch.Tensor,
|
|
147
167
|
n_sds_above_mean_thresh: float,
|
|
168
|
+
n_sds_above_mean_tiled_thresh: float,
|
|
169
|
+
local_threshold_tile_size_px: int,
|
|
148
170
|
threshold_value: int,
|
|
171
|
+
torch_device: str,
|
|
149
172
|
) -> None:
|
|
150
173
|
"""
|
|
151
174
|
Sets each plane (in-place) to threshold_value, where the corresponding
|
|
152
175
|
enhanced_plane > mean + n_sds_above_mean_thresh*std. Each plane will be
|
|
153
176
|
set to zero elsewhere.
|
|
154
177
|
"""
|
|
155
|
-
|
|
178
|
+
z, y, x = enhanced_planes.shape
|
|
156
179
|
|
|
180
|
+
# ---- get per-plane global threshold ----
|
|
181
|
+
planes_1d = enhanced_planes.view(z, -1)
|
|
157
182
|
# add back last dim
|
|
158
|
-
|
|
159
|
-
|
|
160
|
-
|
|
183
|
+
std, mean = torch.std_mean(planes_1d, dim=1, keepdim=True)
|
|
184
|
+
threshold = mean.unsqueeze(2) + n_sds_above_mean_thresh * std.unsqueeze(2)
|
|
185
|
+
above_global = enhanced_planes > threshold
|
|
186
|
+
|
|
187
|
+
# ---- calculate the local tiled threshold ----
|
|
188
|
+
# we do 50% overlap so there's no jumps at boundaries
|
|
189
|
+
stride = local_threshold_tile_size_px // 2
|
|
190
|
+
# make tile even for ease of computation
|
|
191
|
+
tile_size = stride * 2
|
|
192
|
+
# Due to 50% overlap, to get tiles we move the tile by half tile (stride).
|
|
193
|
+
# Total moves will be y // stride - 2 (we start already with mask on first
|
|
194
|
+
# tile). So add back 1 for the first tile. Partial tiles are dropped
|
|
195
|
+
n_y_tiles = max(y // stride - 1, 1) if stride else 1
|
|
196
|
+
n_x_tiles = max(x // stride - 1, 1) if stride else 1
|
|
197
|
+
do_tile_y = n_y_tiles >= 2
|
|
198
|
+
do_tile_x = n_x_tiles >= 2
|
|
199
|
+
# we want at least one axis to have at least two tiles
|
|
200
|
+
if local_threshold_tile_size_px >= 2 and (do_tile_y or do_tile_x):
|
|
201
|
+
# num edge pixels dropped b/c moving by stride would move tile off edge
|
|
202
|
+
y_rem = y % stride
|
|
203
|
+
x_rem = x % stride
|
|
204
|
+
enhanced_planes_raw = enhanced_planes
|
|
205
|
+
if do_tile_y:
|
|
206
|
+
enhanced_planes = enhanced_planes[:, y_rem // 2 :, :]
|
|
207
|
+
if do_tile_x:
|
|
208
|
+
enhanced_planes = enhanced_planes[:, :, x_rem // 2 :]
|
|
209
|
+
|
|
210
|
+
# add empty channel dim after z "batch" dim -> zcyx
|
|
211
|
+
enhanced_planes = enhanced_planes.unsqueeze(1)
|
|
212
|
+
# unfold makes it 3 dim, z, M, L. L is number of tiles, M is tile area
|
|
213
|
+
unfolded = F.unfold(
|
|
214
|
+
enhanced_planes,
|
|
215
|
+
(tile_size if do_tile_y else y, tile_size if do_tile_x else x),
|
|
216
|
+
stride=stride,
|
|
217
|
+
)
|
|
218
|
+
# average the tile areas, for each tile
|
|
219
|
+
std, mean = torch.std_mean(unfolded, dim=1, keepdim=True)
|
|
220
|
+
threshold = mean + n_sds_above_mean_tiled_thresh * std
|
|
221
|
+
|
|
222
|
+
# reshape it back into Y by X tiles, instead of YX being one dim
|
|
223
|
+
threshold = threshold.reshape((z, n_y_tiles, n_x_tiles))
|
|
224
|
+
|
|
225
|
+
# we need total size of n_tiles * stride + stride + rem for the
|
|
226
|
+
# original size. So we add 2 strides and then chop off the excess above
|
|
227
|
+
# rem. We center it because of 50% overlap, the first tile is actually
|
|
228
|
+
# centered in between the first two strides
|
|
229
|
+
offsets = [(0, y), (0, x)]
|
|
230
|
+
for dim, do_tile, n_tiles, n, rem in [
|
|
231
|
+
(1, do_tile_y, n_y_tiles, y, y_rem),
|
|
232
|
+
(2, do_tile_x, n_x_tiles, x, x_rem),
|
|
233
|
+
]:
|
|
234
|
+
if do_tile:
|
|
235
|
+
repeats = (
|
|
236
|
+
torch.ones(n_tiles, dtype=torch.int, device=torch_device)
|
|
237
|
+
* stride
|
|
238
|
+
)
|
|
239
|
+
# add total of 2 additional strides
|
|
240
|
+
repeats[0] = 2 * stride
|
|
241
|
+
repeats[-1] = 2 * stride
|
|
242
|
+
output_size = (n_tiles + 2) * stride
|
|
243
|
+
|
|
244
|
+
threshold = threshold.repeat_interleave(
|
|
245
|
+
repeats, dim=dim, output_size=output_size
|
|
246
|
+
)
|
|
247
|
+
# drop the excess we gained from padding rem to whole stride
|
|
248
|
+
offset = (stride - rem) // 2
|
|
249
|
+
offsets[dim - 1] = offset, n + offset
|
|
250
|
+
|
|
251
|
+
# can't use slice(...) objects in jit code so use actual indices
|
|
252
|
+
(a, b), (c, d) = offsets
|
|
253
|
+
threshold = threshold[:, a:b, c:d]
|
|
254
|
+
|
|
255
|
+
above_local = enhanced_planes_raw > threshold
|
|
256
|
+
above = torch.logical_and(above_global, above_local)
|
|
257
|
+
else:
|
|
258
|
+
above = above_global
|
|
161
259
|
|
|
162
|
-
above = enhanced_planes > threshold
|
|
163
260
|
planes[above] = threshold_value
|
|
164
261
|
# subsequent steps only care about the values that are set to threshold or
|
|
165
262
|
# above in planes. We set values in *planes* to threshold based on the
|
|
@@ -80,23 +80,28 @@ class DetectionSettings:
|
|
|
80
80
|
|
|
81
81
|
voxel_sizes: Tuple[float, float, float] = (1.0, 1.0, 1.0)
|
|
82
82
|
"""
|
|
83
|
-
Tuple of voxel sizes in each dimension (z, y, x). We use this
|
|
84
|
-
from `um` to pixel sizes.
|
|
83
|
+
Tuple of voxel sizes (microns) in each dimension (z, y, x). We use this
|
|
84
|
+
to convert from `um` to pixel sizes.
|
|
85
85
|
"""
|
|
86
86
|
|
|
87
87
|
soma_spread_factor: float = 1.4
|
|
88
|
-
"""
|
|
88
|
+
"""
|
|
89
|
+
Cell spread factor for determining the largest cell volume before
|
|
90
|
+
splitting up cell clusters. Structures with spherical volume of
|
|
91
|
+
diameter `soma_spread_factor * soma_diameter` or less will not be
|
|
92
|
+
split.
|
|
93
|
+
"""
|
|
89
94
|
|
|
90
95
|
soma_diameter_um: float = 16
|
|
91
96
|
"""
|
|
92
|
-
Diameter of a typical soma in
|
|
93
|
-
split.
|
|
97
|
+
Diameter of a typical soma in-plane (xy) in microns.
|
|
94
98
|
"""
|
|
95
99
|
|
|
96
100
|
max_cluster_size_um3: float = 100_000
|
|
97
101
|
"""
|
|
98
|
-
|
|
99
|
-
|
|
102
|
+
Largest detected cell cluster (in cubic um) where splitting
|
|
103
|
+
should be attempted. Clusters above this size will be labeled
|
|
104
|
+
as artifacts.
|
|
100
105
|
"""
|
|
101
106
|
|
|
102
107
|
ball_xy_size_um: float = 6
|
|
@@ -116,17 +121,41 @@ class DetectionSettings:
|
|
|
116
121
|
|
|
117
122
|
ball_overlap_fraction: float = 0.6
|
|
118
123
|
"""
|
|
119
|
-
Fraction of
|
|
120
|
-
|
|
124
|
+
Fraction of the 3d ball filter needed to be filled by foreground voxels,
|
|
125
|
+
centered on a voxel, to retain the voxel.
|
|
121
126
|
"""
|
|
122
127
|
|
|
123
128
|
log_sigma_size: float = 0.2
|
|
124
|
-
"""
|
|
129
|
+
"""
|
|
130
|
+
Gaussian filter width (as a fraction of soma diameter) used during
|
|
131
|
+
2d in-plane Laplacian of Gaussian filtering.
|
|
132
|
+
"""
|
|
125
133
|
|
|
126
134
|
n_sds_above_mean_thresh: float = 10
|
|
127
135
|
"""
|
|
128
|
-
|
|
129
|
-
|
|
136
|
+
Per-plane intensity threshold (the number of standard deviations
|
|
137
|
+
above the mean) of the 2d filtered planes used to mark pixels as
|
|
138
|
+
foreground or background.
|
|
139
|
+
"""
|
|
140
|
+
|
|
141
|
+
n_sds_above_mean_tiled_thresh: float = 10
|
|
142
|
+
"""
|
|
143
|
+
Per-plane, per-tile intensity threshold (the number of standard deviations
|
|
144
|
+
above the mean) for the filtered 2d planes used to mark pixels as
|
|
145
|
+
foreground or background. When used (tile size is not zero), a pixel is
|
|
146
|
+
marked as foreground if its intensity is above both the per-plane and
|
|
147
|
+
per-tile threshold. I.e. it's above the set number of standard deviations
|
|
148
|
+
of the per-plane average and of the per-plane per-tile average for the tile
|
|
149
|
+
that contains it.
|
|
150
|
+
"""
|
|
151
|
+
|
|
152
|
+
tiled_thresh_tile_size: float | None = None
|
|
153
|
+
"""
|
|
154
|
+
The tile size used to tile the x, y plane to calculate the local average
|
|
155
|
+
intensity for the tiled threshold. The value is multiplied by soma
|
|
156
|
+
diameter (i.e. 1 means one soma diameter). If zero or None, the tiled
|
|
157
|
+
threshold is disabled and only the per-plane threshold is used. Tiling is
|
|
158
|
+
done with 50% overlap when striding.
|
|
130
159
|
"""
|
|
131
160
|
|
|
132
161
|
outlier_keep: bool = False
|
|
@@ -180,6 +209,14 @@ class DetectionSettings:
|
|
|
180
209
|
to run on the first GPU.
|
|
181
210
|
"""
|
|
182
211
|
|
|
212
|
+
pin_memory: bool = False
|
|
213
|
+
"""
|
|
214
|
+
Pins data to be sent to the GPU to the CPU memory. This allows faster GPU
|
|
215
|
+
data speeds, but can only be used if the data used by the GPU can stay in
|
|
216
|
+
the CPU RAM while the GPU uses it. I.e. there's enough RAM. Otherwise, if
|
|
217
|
+
there's a risk of the RAM being paged, it shouldn't be used.
|
|
218
|
+
"""
|
|
219
|
+
|
|
183
220
|
n_free_cpus: int = 2
|
|
184
221
|
"""
|
|
185
222
|
Number of free CPU cores to keep available and not use during parallel
|
|
@@ -191,6 +228,8 @@ class DetectionSettings:
|
|
|
191
228
|
"""
|
|
192
229
|
During the structure splitting phase we iteratively shrink the bright areas
|
|
193
230
|
and re-filter with the 3d filter. This is the number of iterations to do.
|
|
231
|
+
Each iteration reduces the cluster size by the voxels not retained in the
|
|
232
|
+
previous iteration.
|
|
194
233
|
|
|
195
234
|
This is a maximum because we also stop if there are no more structures left
|
|
196
235
|
during any iteration.
|