cellfinder 1.3.3__py3-none-any.whl → 1.4.0a0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- cellfinder/core/classify/classify.py +1 -1
- cellfinder/core/detect/detect.py +118 -183
- cellfinder/core/detect/filters/plane/classical_filter.py +339 -37
- cellfinder/core/detect/filters/plane/plane_filter.py +137 -55
- cellfinder/core/detect/filters/plane/tile_walker.py +126 -60
- cellfinder/core/detect/filters/setup_filters.py +422 -65
- cellfinder/core/detect/filters/volume/ball_filter.py +313 -315
- cellfinder/core/detect/filters/volume/structure_detection.py +73 -35
- cellfinder/core/detect/filters/volume/structure_splitting.py +160 -96
- cellfinder/core/detect/filters/volume/volume_filter.py +444 -123
- cellfinder/core/main.py +6 -2
- cellfinder/core/tools/IO.py +45 -0
- cellfinder/core/tools/threading.py +380 -0
- cellfinder/core/tools/tools.py +128 -6
- {cellfinder-1.3.3.dist-info → cellfinder-1.4.0a0.dist-info}/METADATA +3 -2
- {cellfinder-1.3.3.dist-info → cellfinder-1.4.0a0.dist-info}/RECORD +20 -18
- {cellfinder-1.3.3.dist-info → cellfinder-1.4.0a0.dist-info}/WHEEL +1 -1
- {cellfinder-1.3.3.dist-info → cellfinder-1.4.0a0.dist-info}/LICENSE +0 -0
- {cellfinder-1.3.3.dist-info → cellfinder-1.4.0a0.dist-info}/entry_points.txt +0 -0
- {cellfinder-1.3.3.dist-info → cellfinder-1.4.0a0.dist-info}/top_level.txt +0 -0
|
@@ -1,29 +1,35 @@
|
|
|
1
|
+
import math
|
|
1
2
|
from functools import lru_cache
|
|
3
|
+
from typing import Optional
|
|
2
4
|
|
|
3
5
|
import numpy as np
|
|
4
|
-
|
|
5
|
-
|
|
6
|
-
from numba.experimental import jitclass
|
|
6
|
+
import torch
|
|
7
|
+
import torch.nn.functional as F
|
|
7
8
|
|
|
8
9
|
from cellfinder.core.tools.array_operations import bin_mean_3d
|
|
9
10
|
from cellfinder.core.tools.geometry import make_sphere
|
|
10
11
|
|
|
11
|
-
DEBUG = False
|
|
12
12
|
|
|
13
|
-
|
|
14
|
-
|
|
15
|
-
|
|
13
|
+
class InvalidVolume(ValueError):
|
|
14
|
+
"""
|
|
15
|
+
Raised when the volume passed to BallFilter is too small or does not meet
|
|
16
|
+
requirements.
|
|
17
|
+
"""
|
|
18
|
+
|
|
19
|
+
pass
|
|
16
20
|
|
|
17
21
|
|
|
18
22
|
@lru_cache(maxsize=50)
|
|
19
23
|
def get_kernel(ball_xy_size: int, ball_z_size: int) -> np.ndarray:
|
|
20
|
-
|
|
21
|
-
|
|
22
|
-
|
|
23
|
-
|
|
24
|
-
|
|
25
|
-
|
|
26
|
-
|
|
24
|
+
"""
|
|
25
|
+
Create a spherical kernel.
|
|
26
|
+
|
|
27
|
+
This is done by:
|
|
28
|
+
1. Generating a binary sphere at a resolution *upscale_factor* larger
|
|
29
|
+
than desired.
|
|
30
|
+
2. Downscaling the binary sphere to get a 'fuzzy' sphere at the
|
|
31
|
+
original intended scale
|
|
32
|
+
"""
|
|
27
33
|
upscale_factor: int = 7
|
|
28
34
|
upscaled_kernel_shape = (
|
|
29
35
|
upscale_factor * ball_xy_size,
|
|
@@ -42,11 +48,11 @@ def get_kernel(ball_xy_size: int, ball_z_size: int) -> np.ndarray:
|
|
|
42
48
|
upscaled_ball_radius,
|
|
43
49
|
upscaled_ball_centre_position,
|
|
44
50
|
)
|
|
45
|
-
sphere_kernel = sphere_kernel.astype(np.
|
|
51
|
+
sphere_kernel = sphere_kernel.astype(np.float32)
|
|
46
52
|
kernel = bin_mean_3d(
|
|
47
53
|
sphere_kernel,
|
|
48
|
-
bin_height=upscale_factor,
|
|
49
54
|
bin_width=upscale_factor,
|
|
55
|
+
bin_height=upscale_factor,
|
|
50
56
|
bin_depth=upscale_factor,
|
|
51
57
|
)
|
|
52
58
|
|
|
@@ -59,359 +65,351 @@ def get_kernel(ball_xy_size: int, ball_z_size: int) -> np.ndarray:
|
|
|
59
65
|
return kernel
|
|
60
66
|
|
|
61
67
|
|
|
62
|
-
# volume indices/size is 64 bit for very large brains(!)
|
|
63
|
-
spec = [
|
|
64
|
-
("ball_xy_size", types.uint32),
|
|
65
|
-
("ball_z_size", types.uint32),
|
|
66
|
-
("tile_step_width", types.uint64),
|
|
67
|
-
("tile_step_height", types.uint64),
|
|
68
|
-
("THRESHOLD_VALUE", types.uint32),
|
|
69
|
-
("SOMA_CENTRE_VALUE", types.uint32),
|
|
70
|
-
("overlap_fraction", types.float64),
|
|
71
|
-
("overlap_threshold", types.float64),
|
|
72
|
-
("middle_z_idx", types.uint32),
|
|
73
|
-
("_num_z_added", types.uint32),
|
|
74
|
-
("kernel", float_3d_type),
|
|
75
|
-
("volume", uint32_3d_type),
|
|
76
|
-
("inside_brain_tiles", bool_3d_type),
|
|
77
|
-
]
|
|
78
|
-
|
|
79
|
-
|
|
80
|
-
@jitclass(spec=spec)
|
|
81
68
|
class BallFilter:
|
|
82
69
|
"""
|
|
83
70
|
A 3D ball filter.
|
|
84
71
|
|
|
85
|
-
This runs a spherical kernel across the
|
|
72
|
+
This runs a spherical kernel across the 2d planar dimensions
|
|
86
73
|
of a *ball_z_size* stack of planes, and marks pixels in the middle
|
|
87
|
-
plane of the stack that have a high enough intensity
|
|
88
|
-
spherical kernel.
|
|
74
|
+
plane of the stack that have a high enough intensity over the
|
|
75
|
+
the spherical kernel.
|
|
76
|
+
|
|
77
|
+
Parameters
|
|
78
|
+
----------
|
|
79
|
+
plane_height, plane_width : int
|
|
80
|
+
Height/width of the planes.
|
|
81
|
+
ball_xy_size : int
|
|
82
|
+
Diameter of the spherical kernel in the x/y dimensions.
|
|
83
|
+
ball_z_size : int
|
|
84
|
+
Diameter of the spherical kernel in the z dimension.
|
|
85
|
+
Equal to the number of planes stacked to filter
|
|
86
|
+
the central plane of the stack.
|
|
87
|
+
overlap_fraction : float
|
|
88
|
+
The fraction of pixels within the spherical kernel that
|
|
89
|
+
have to be over *threshold_value* for a pixel to be marked
|
|
90
|
+
as having a high intensity.
|
|
91
|
+
threshold_value : int
|
|
92
|
+
Value above which an individual pixel is considered to have
|
|
93
|
+
a high intensity.
|
|
94
|
+
soma_centre_value : int
|
|
95
|
+
Value used to mark pixels with a high enough intensity.
|
|
96
|
+
tile_height, tile_width : int
|
|
97
|
+
Width/height of individual tiles in the mask generated by
|
|
98
|
+
2D filtering.
|
|
99
|
+
dtype : str
|
|
100
|
+
The data-type of the input planes and the type to use internally.
|
|
101
|
+
E.g. "float32".
|
|
102
|
+
batch_size: int
|
|
103
|
+
The number of planes that will be typically passed in a single batch to
|
|
104
|
+
`append`. This is only used to calculate `num_batches_before_ready`.
|
|
105
|
+
Defaults to 1.
|
|
106
|
+
torch_device: str
|
|
107
|
+
The device on which the data and processing occurs on. Can be e.g.
|
|
108
|
+
"cpu", "cuda" etc. Defaults to "cpu". Any data passed to the filter
|
|
109
|
+
must be on this device. Returned data will also be on this device.
|
|
110
|
+
use_mask : bool
|
|
111
|
+
Whether tiling masks will be used in `append`. If False, tile masks
|
|
112
|
+
won't be passed in and/or will be ignored. Defaults to True.
|
|
113
|
+
"""
|
|
114
|
+
|
|
115
|
+
num_batches_before_ready: int
|
|
89
116
|
"""
|
|
117
|
+
The number of batches of size `batch_size` passed to `append`
|
|
118
|
+
before `ready` would return True.
|
|
119
|
+
"""
|
|
120
|
+
|
|
121
|
+
# the inside brain tiled mask, if tiles are used (use_mask is True)
|
|
122
|
+
inside_brain_tiles: Optional[torch.Tensor] = None
|
|
90
123
|
|
|
91
124
|
def __init__(
|
|
92
125
|
self,
|
|
93
|
-
plane_width: int,
|
|
94
126
|
plane_height: int,
|
|
127
|
+
plane_width: int,
|
|
95
128
|
ball_xy_size: int,
|
|
96
129
|
ball_z_size: int,
|
|
97
130
|
overlap_fraction: float,
|
|
98
|
-
tile_step_width: int,
|
|
99
|
-
tile_step_height: int,
|
|
100
131
|
threshold_value: int,
|
|
101
132
|
soma_centre_value: int,
|
|
133
|
+
tile_height: int,
|
|
134
|
+
tile_width: int,
|
|
135
|
+
dtype: str,
|
|
136
|
+
batch_size: int = 1,
|
|
137
|
+
torch_device: str = "cpu",
|
|
138
|
+
use_mask: bool = True,
|
|
102
139
|
):
|
|
103
|
-
"""
|
|
104
|
-
Parameters
|
|
105
|
-
----------
|
|
106
|
-
plane_width, plane_height :
|
|
107
|
-
Width/height of the planes.
|
|
108
|
-
ball_xy_size :
|
|
109
|
-
Diameter of the spherical kernel in the x/y dimensions.
|
|
110
|
-
ball_z_size :
|
|
111
|
-
Diameter of the spherical kernel in the z dimension.
|
|
112
|
-
Equal to the number of planes that stacked to filter
|
|
113
|
-
the central plane of the stack.
|
|
114
|
-
overlap_fraction :
|
|
115
|
-
The fraction of pixels within the spherical kernel that
|
|
116
|
-
have to be over *threshold_value* for a pixel to be marked
|
|
117
|
-
as having a high intensity.
|
|
118
|
-
tile_step_width, tile_step_height :
|
|
119
|
-
Width/height of individual tiles in the mask generated by
|
|
120
|
-
2D filtering.
|
|
121
|
-
threshold_value :
|
|
122
|
-
Value above which an individual pixel is considered to have
|
|
123
|
-
a high intensity.
|
|
124
|
-
soma_centre_value :
|
|
125
|
-
Value used to mark pixels with a high enough intensity.
|
|
126
|
-
"""
|
|
127
140
|
self.ball_xy_size = ball_xy_size
|
|
128
141
|
self.ball_z_size = ball_z_size
|
|
129
142
|
self.overlap_fraction = overlap_fraction
|
|
130
|
-
self.
|
|
131
|
-
self.
|
|
143
|
+
self.tile_step_height = tile_height
|
|
144
|
+
self.tile_step_width = tile_width
|
|
145
|
+
|
|
146
|
+
d1 = plane_height
|
|
147
|
+
d2 = plane_width
|
|
148
|
+
ball_xy_size = self.ball_xy_size
|
|
149
|
+
if d1 < ball_xy_size or d2 < ball_xy_size:
|
|
150
|
+
raise InvalidVolume(
|
|
151
|
+
f"Invalid plane size {d1}x{d2}. Needs to be at least "
|
|
152
|
+
f"{ball_xy_size} in each dimension"
|
|
153
|
+
)
|
|
132
154
|
|
|
133
155
|
self.THRESHOLD_VALUE = threshold_value
|
|
134
156
|
self.SOMA_CENTRE_VALUE = soma_centre_value
|
|
135
157
|
|
|
136
|
-
#
|
|
137
|
-
|
|
138
|
-
|
|
139
|
-
self.
|
|
140
|
-
|
|
141
|
-
|
|
158
|
+
# kernel comes in as XYZ, change to ZYX so it aligns with data
|
|
159
|
+
kernel = np.moveaxis(get_kernel(ball_xy_size, self.ball_z_size), 2, 0)
|
|
160
|
+
self.overlap_threshold = np.sum(self.overlap_fraction * kernel)
|
|
161
|
+
self.kernel_xy_size = kernel.shape[-2:]
|
|
162
|
+
self.kernel_z_size = self.ball_z_size
|
|
163
|
+
|
|
164
|
+
# convert to right type and pin for faster copying
|
|
165
|
+
kernel = torch.from_numpy(kernel).type(getattr(torch, dtype))
|
|
166
|
+
if torch_device != "cpu":
|
|
167
|
+
# torch at one point threw a cuda memory error when splitting cells
|
|
168
|
+
# on cpu due to pinning. It's best to only pin on using cuda
|
|
169
|
+
kernel.pin_memory()
|
|
170
|
+
# add 2 dimensions at the start so we have 11ZYX We need this shape in
|
|
171
|
+
# the conv step
|
|
172
|
+
self.kernel = (
|
|
173
|
+
kernel.unsqueeze(0)
|
|
174
|
+
.unsqueeze(0)
|
|
175
|
+
.to(device=torch_device, non_blocking=True)
|
|
176
|
+
)
|
|
142
177
|
|
|
143
|
-
|
|
144
|
-
|
|
145
|
-
|
|
146
|
-
|
|
147
|
-
|
|
178
|
+
self.num_batches_before_ready = int(
|
|
179
|
+
math.ceil(self.ball_z_size / batch_size)
|
|
180
|
+
)
|
|
181
|
+
# Stores the current planes that are being filtered. Start with no data
|
|
182
|
+
self.volume = torch.empty(
|
|
183
|
+
(0, plane_height, plane_width),
|
|
184
|
+
dtype=getattr(torch, dtype),
|
|
148
185
|
)
|
|
149
186
|
# Index of the middle plane in the volume
|
|
150
|
-
self.middle_z_idx = int(np.floor(ball_z_size / 2))
|
|
151
|
-
self._num_z_added = 0
|
|
187
|
+
self.middle_z_idx = int(np.floor(self.ball_z_size / 2))
|
|
152
188
|
|
|
189
|
+
if not use_mask:
|
|
190
|
+
return
|
|
153
191
|
# first axis is z
|
|
154
|
-
|
|
192
|
+
n_vertical_tiles = int(np.ceil(plane_height / self.tile_step_height))
|
|
193
|
+
n_horizontal_tiles = int(np.ceil(plane_width / self.tile_step_width))
|
|
194
|
+
# Stores tile masks. We start with no data
|
|
195
|
+
self.inside_brain_tiles = torch.empty(
|
|
155
196
|
(
|
|
156
|
-
|
|
157
|
-
|
|
158
|
-
|
|
197
|
+
0,
|
|
198
|
+
n_vertical_tiles,
|
|
199
|
+
n_horizontal_tiles,
|
|
159
200
|
),
|
|
160
|
-
dtype=
|
|
201
|
+
dtype=torch.bool,
|
|
161
202
|
)
|
|
162
203
|
|
|
163
204
|
@property
|
|
164
|
-
def
|
|
205
|
+
def first_valid_plane(self) -> int:
|
|
165
206
|
"""
|
|
166
|
-
|
|
207
|
+
The index in `self.volume` (or the planes passed in) that will be the
|
|
208
|
+
first plane returned from `get_processed_planes`.
|
|
209
|
+
|
|
210
|
+
E.g. if `ball_z_size` is 3, then this may return 1. Meaning the second
|
|
211
|
+
plane passed to `append` (index 1), will be the first returned plane
|
|
212
|
+
by `get_processed_planes`.
|
|
167
213
|
"""
|
|
168
|
-
return self.
|
|
214
|
+
return int(math.floor(self.ball_z_size / 2))
|
|
169
215
|
|
|
170
|
-
|
|
216
|
+
@property
|
|
217
|
+
def remaining_planes(self) -> int:
|
|
171
218
|
"""
|
|
172
|
-
|
|
219
|
+
The number of planes in `self.volume` (or the planes passed in) that
|
|
220
|
+
will remain unprocessed after all the planes have been `walk`ed
|
|
221
|
+
and `get_processed_planes` called.
|
|
222
|
+
|
|
223
|
+
E.g. if `ball_z_size` is 3, then this may return 1. Meaning the last
|
|
224
|
+
plane passed to `append`, will never be returned by
|
|
225
|
+
`get_processed_planes` because the filter "center" never overlapped
|
|
226
|
+
with it.
|
|
173
227
|
"""
|
|
174
|
-
|
|
175
|
-
assert [e for e in plane.shape[:2]] == [
|
|
176
|
-
e for e in self.volume.shape[1:]
|
|
177
|
-
], 'plane shape mismatch, expected "{}", got "{}"'.format(
|
|
178
|
-
[e for e in self.volume.shape[1:]],
|
|
179
|
-
[e for e in plane.shape[:2]],
|
|
180
|
-
)
|
|
181
|
-
assert [e for e in mask.shape[:2]] == [
|
|
182
|
-
e for e in self.inside_brain_tiles.shape[1:]
|
|
183
|
-
], 'mask shape mismatch, expected"{}", got {}"'.format(
|
|
184
|
-
[e for e in self.inside_brain_tiles.shape[1:]],
|
|
185
|
-
[e for e in mask.shape[:2]],
|
|
186
|
-
)
|
|
228
|
+
return self.ball_z_size - self.first_valid_plane - 1
|
|
187
229
|
|
|
188
|
-
|
|
189
|
-
|
|
190
|
-
|
|
191
|
-
|
|
192
|
-
|
|
193
|
-
|
|
194
|
-
|
|
230
|
+
@property
|
|
231
|
+
def ready(self) -> bool:
|
|
232
|
+
"""
|
|
233
|
+
Return whether enough planes have been appended to run the filter
|
|
234
|
+
using `walk`.
|
|
235
|
+
"""
|
|
236
|
+
return self.volume.shape[0] >= self.kernel_z_size
|
|
237
|
+
|
|
238
|
+
def append(
|
|
239
|
+
self, planes: torch.Tensor, masks: Optional[torch.Tensor] = None
|
|
240
|
+
) -> None:
|
|
241
|
+
"""
|
|
242
|
+
Add a new z-stack to the filter.
|
|
195
243
|
|
|
196
|
-
|
|
197
|
-
|
|
198
|
-
|
|
244
|
+
Previous stacks passed to `append` are removed, except enough planes
|
|
245
|
+
at the top of the previous z-stack to provide padding so we can filter
|
|
246
|
+
starting from the first plane in `planes`. The first time we call
|
|
247
|
+
`append`, `first_valid_plane` is the first plane to actually be
|
|
248
|
+
filtered in the z-stack due to lack of padding.
|
|
199
249
|
|
|
200
|
-
|
|
201
|
-
|
|
202
|
-
self.inside_brain_tiles[idx, :, :] = mask
|
|
250
|
+
So make sure to call `walk`/`get_processed_planes` before calling
|
|
251
|
+
`append` again.
|
|
203
252
|
|
|
204
|
-
|
|
205
|
-
|
|
206
|
-
|
|
253
|
+
Parameters
|
|
254
|
+
----------
|
|
255
|
+
planes : torch.Tensor
|
|
256
|
+
The z-stack. There can be one or more planes in the stack, but it
|
|
257
|
+
must have 3 dimensions. Input data is not modified.
|
|
258
|
+
masks : torch.Tensor
|
|
259
|
+
A z-stack tile mask, indicating for each tile whether it's in or
|
|
260
|
+
outside the brain. If the latter it's excluded.
|
|
261
|
+
|
|
262
|
+
If `use_mask` was True, this must be provided. If False, this
|
|
263
|
+
parameter will be ignored.
|
|
264
|
+
|
|
265
|
+
Input data is not modified.
|
|
207
266
|
"""
|
|
208
|
-
|
|
209
|
-
|
|
210
|
-
|
|
211
|
-
|
|
212
|
-
|
|
213
|
-
|
|
214
|
-
|
|
215
|
-
|
|
216
|
-
|
|
217
|
-
|
|
218
|
-
tile_mask_covered_img_height = (
|
|
219
|
-
self.inside_brain_tiles.shape[2] * self.tile_step_height
|
|
220
|
-
)
|
|
221
|
-
# Get maximum offsets for the ball
|
|
222
|
-
max_width = tile_mask_covered_img_width - self.ball_xy_size
|
|
223
|
-
max_height = tile_mask_covered_img_height - self.ball_xy_size
|
|
224
|
-
|
|
225
|
-
# we have to pass the raw volume so walk doesn't use its edits as it
|
|
226
|
-
# processes the volume. self.volume is the one edited in place
|
|
227
|
-
input_volume = self.volume.copy()
|
|
228
|
-
|
|
229
|
-
if parallel:
|
|
230
|
-
_walk_parallel(
|
|
231
|
-
max_height,
|
|
232
|
-
max_width,
|
|
233
|
-
self.tile_step_width,
|
|
234
|
-
self.tile_step_height,
|
|
235
|
-
self.inside_brain_tiles,
|
|
236
|
-
input_volume,
|
|
237
|
-
self.volume,
|
|
238
|
-
self.kernel,
|
|
239
|
-
ball_radius,
|
|
240
|
-
self.middle_z_idx,
|
|
241
|
-
self.overlap_threshold,
|
|
242
|
-
self.THRESHOLD_VALUE,
|
|
243
|
-
self.SOMA_CENTRE_VALUE,
|
|
267
|
+
if self.volume.shape[0]:
|
|
268
|
+
if self.volume.shape[0] < self.kernel_z_size:
|
|
269
|
+
num_remaining_with_padding = 0
|
|
270
|
+
else:
|
|
271
|
+
num_remaining = self.kernel_z_size - (self.middle_z_idx + 1)
|
|
272
|
+
num_remaining_with_padding = num_remaining + self.middle_z_idx
|
|
273
|
+
|
|
274
|
+
self.volume = torch.cat(
|
|
275
|
+
[self.volume[-num_remaining_with_padding:, :, :], planes],
|
|
276
|
+
dim=0,
|
|
244
277
|
)
|
|
278
|
+
|
|
279
|
+
if self.inside_brain_tiles is not None:
|
|
280
|
+
self.inside_brain_tiles = torch.cat(
|
|
281
|
+
[
|
|
282
|
+
self.inside_brain_tiles[
|
|
283
|
+
-num_remaining_with_padding:, :, :
|
|
284
|
+
],
|
|
285
|
+
masks,
|
|
286
|
+
],
|
|
287
|
+
dim=0,
|
|
288
|
+
)
|
|
245
289
|
else:
|
|
246
|
-
|
|
247
|
-
|
|
248
|
-
|
|
249
|
-
self.tile_step_width,
|
|
250
|
-
self.tile_step_height,
|
|
251
|
-
self.inside_brain_tiles,
|
|
252
|
-
input_volume,
|
|
253
|
-
self.volume,
|
|
254
|
-
self.kernel,
|
|
255
|
-
ball_radius,
|
|
256
|
-
self.middle_z_idx,
|
|
257
|
-
self.overlap_threshold,
|
|
258
|
-
self.THRESHOLD_VALUE,
|
|
259
|
-
self.SOMA_CENTRE_VALUE,
|
|
260
|
-
)
|
|
290
|
+
self.volume = planes.clone()
|
|
291
|
+
if self.inside_brain_tiles is not None:
|
|
292
|
+
self.inside_brain_tiles = masks.clone()
|
|
261
293
|
|
|
294
|
+
def get_processed_planes(self) -> np.ndarray:
|
|
295
|
+
"""
|
|
296
|
+
After passing enough planes to `append`, and after `walk`, this returns
|
|
297
|
+
a copy of the processed planes as a numpy z-stack.
|
|
262
298
|
|
|
263
|
-
|
|
264
|
-
|
|
265
|
-
|
|
266
|
-
|
|
267
|
-
x_end: int,
|
|
268
|
-
y_start: int,
|
|
269
|
-
y_end: int,
|
|
270
|
-
overlap_threshold: float,
|
|
271
|
-
threshold_value: int,
|
|
272
|
-
kernel: np.ndarray,
|
|
273
|
-
) -> bool: # Highly optimised because most time critical
|
|
274
|
-
"""
|
|
275
|
-
For each pixel in cube in volume that is greater than THRESHOLD_VALUE, sum
|
|
276
|
-
up the corresponding pixels in *kernel*. If the total is less than
|
|
277
|
-
overlap_threshold, return False, otherwise return True.
|
|
299
|
+
It only starts returning planes corresponding to plane
|
|
300
|
+
`first_valid_plane` relative to the first planes passed to `append`.
|
|
301
|
+
E.g. if `ball_z_size` is 3 and `first_valid_plane` is 1, and you passed
|
|
302
|
+
5 planes total to `append`, then this will have returned planes [1, 3].
|
|
278
303
|
|
|
279
|
-
|
|
280
|
-
|
|
281
|
-
|
|
304
|
+
Notice the last plane was not included, because we return only "middle"
|
|
305
|
+
planes - planes that can correspond to the center of a ball.
|
|
306
|
+
"""
|
|
307
|
+
if not self.ready:
|
|
308
|
+
raise TypeError("Not enough planes were appended")
|
|
309
|
+
|
|
310
|
+
num_processed = self.volume.shape[0] - self.kernel_z_size + 1
|
|
311
|
+
assert num_processed
|
|
312
|
+
middle = self.middle_z_idx
|
|
313
|
+
planes = (
|
|
314
|
+
self.volume[middle : middle + num_processed, :, :]
|
|
315
|
+
.cpu()
|
|
316
|
+
.numpy()
|
|
317
|
+
.copy()
|
|
318
|
+
)
|
|
319
|
+
return planes
|
|
282
320
|
|
|
283
|
-
|
|
284
|
-
|
|
285
|
-
|
|
286
|
-
3D array.
|
|
287
|
-
x_start, x_end, y_start, y_end :
|
|
288
|
-
The start and end indices in volume that form the cube. End is
|
|
289
|
-
exclusive
|
|
290
|
-
overlap_threshold :
|
|
291
|
-
Threshold above which to return True.
|
|
292
|
-
threshold_value :
|
|
293
|
-
Value above which a pixel is marked as being part of a cell.
|
|
294
|
-
kernel :
|
|
295
|
-
3D array, with the same shape as *cube* in the volume.
|
|
296
|
-
"""
|
|
297
|
-
current_overlap_value = 0.0
|
|
298
|
-
|
|
299
|
-
middle = np.floor(volume.shape[0] / 2) + 1
|
|
300
|
-
halfway_overlap_thresh = (
|
|
301
|
-
overlap_threshold * 0.4
|
|
302
|
-
) # FIXME: do not hard code value
|
|
303
|
-
|
|
304
|
-
for z in range(volume.shape[0]):
|
|
305
|
-
# TODO: OPTIMISE: step from middle to outer boundaries to check
|
|
306
|
-
# more data first
|
|
307
|
-
#
|
|
308
|
-
# If halfway through the array, and the overlap value isn't more than
|
|
309
|
-
# 0.4 * the overlap threshold, return
|
|
310
|
-
if z == middle and current_overlap_value < halfway_overlap_thresh:
|
|
311
|
-
return False # DEBUG: optimisation attempt
|
|
312
|
-
|
|
313
|
-
for y in range(y_start, y_end):
|
|
314
|
-
for x in range(x_start, x_end):
|
|
315
|
-
# includes self.SOMA_CENTRE_VALUE
|
|
316
|
-
if volume[z, x, y] >= threshold_value:
|
|
317
|
-
# x/y must be shifted in kernel because we x/y is relative
|
|
318
|
-
# to the full volume, so shift it to relative to the cube
|
|
319
|
-
current_overlap_value += kernel[
|
|
320
|
-
x - x_start, y - y_start, z
|
|
321
|
-
]
|
|
322
|
-
|
|
323
|
-
return current_overlap_value > overlap_threshold
|
|
324
|
-
|
|
325
|
-
|
|
326
|
-
@njit
|
|
327
|
-
def _is_tile_to_check(
|
|
328
|
-
x: int,
|
|
329
|
-
y: int,
|
|
330
|
-
middle_z: int,
|
|
331
|
-
tile_step_width: int,
|
|
332
|
-
tile_step_height: int,
|
|
333
|
-
inside_brain_tiles: np.ndarray,
|
|
334
|
-
) -> bool: # Highly optimised because most time critical
|
|
335
|
-
"""
|
|
336
|
-
Check if the tile containing pixel (x, y) is a tile that needs checking.
|
|
337
|
-
"""
|
|
338
|
-
x_in_mask = x // tile_step_width # TEST: test bounds (-1 range)
|
|
339
|
-
y_in_mask = y // tile_step_height # TEST: test bounds (-1 range)
|
|
340
|
-
return inside_brain_tiles[middle_z, x_in_mask, y_in_mask]
|
|
321
|
+
def walk(self) -> None:
|
|
322
|
+
"""
|
|
323
|
+
Applies the filter to all the planes passed to `append`.
|
|
341
324
|
|
|
325
|
+
May only be called if `ready` was True.
|
|
342
326
|
|
|
343
|
-
|
|
344
|
-
|
|
345
|
-
|
|
346
|
-
|
|
327
|
+
You can get the processed planes from `get_processed_planes`.
|
|
328
|
+
"""
|
|
329
|
+
if not self.ready:
|
|
330
|
+
raise TypeError("Called walk before enough planes were appended")
|
|
331
|
+
|
|
332
|
+
_walk(
|
|
333
|
+
self.kernel_z_size,
|
|
334
|
+
self.middle_z_idx,
|
|
335
|
+
self.tile_step_height,
|
|
336
|
+
self.tile_step_width,
|
|
337
|
+
self.overlap_threshold,
|
|
338
|
+
self.THRESHOLD_VALUE,
|
|
339
|
+
self.SOMA_CENTRE_VALUE,
|
|
340
|
+
self.kernel,
|
|
341
|
+
self.volume,
|
|
342
|
+
self.inside_brain_tiles,
|
|
343
|
+
)
|
|
344
|
+
|
|
345
|
+
|
|
346
|
+
@torch.jit.script
|
|
347
|
+
def _walk(
|
|
348
|
+
kernel_z_size: int,
|
|
349
|
+
middle: int,
|
|
347
350
|
tile_step_height: int,
|
|
348
|
-
|
|
349
|
-
input_volume: np.ndarray,
|
|
350
|
-
volume: np.ndarray,
|
|
351
|
-
kernel: np.ndarray,
|
|
352
|
-
ball_radius: int,
|
|
353
|
-
middle_z: int,
|
|
351
|
+
tile_step_width: int,
|
|
354
352
|
overlap_threshold: float,
|
|
355
353
|
threshold_value: int,
|
|
356
354
|
soma_centre_value: int,
|
|
357
|
-
|
|
358
|
-
|
|
359
|
-
|
|
360
|
-
|
|
355
|
+
kernel: torch.Tensor,
|
|
356
|
+
volume: torch.Tensor,
|
|
357
|
+
inside_brain_tiles: Optional[torch.Tensor],
|
|
358
|
+
):
|
|
359
|
+
num_process = volume.shape[0] - kernel_z_size + 1
|
|
360
|
+
height, width = volume.shape[1:]
|
|
361
|
+
num_z = kernel.shape[2]
|
|
362
|
+
|
|
363
|
+
# threshold volume so it's zero/one. And add two dims at start so
|
|
364
|
+
# it's 11ZYX
|
|
365
|
+
volume_tresh = (
|
|
366
|
+
(volume >= threshold_value)
|
|
367
|
+
.unsqueeze(0)
|
|
368
|
+
.unsqueeze(0)
|
|
369
|
+
.type(kernel.dtype)
|
|
370
|
+
)
|
|
361
371
|
|
|
362
|
-
|
|
372
|
+
# we do a plane at a time, volume: i:i+num_z, for plane i+middle
|
|
373
|
+
for i in range(num_process):
|
|
374
|
+
# spherical kernel is symmetric so convolution=correlation. Use
|
|
375
|
+
# binary threshold mask over the kernel to sum the value of the
|
|
376
|
+
# kernel at voxels that are bright
|
|
377
|
+
overlap = F.conv3d(
|
|
378
|
+
volume_tresh[:, :, i : i + num_z, :, :],
|
|
379
|
+
kernel,
|
|
380
|
+
stride=1,
|
|
381
|
+
padding="valid",
|
|
382
|
+
)[0, 0, 0, :, :]
|
|
383
|
+
overlaps = overlap > overlap_threshold
|
|
384
|
+
|
|
385
|
+
# only edit the volume that is valid - conv excludes edges so we
|
|
386
|
+
# only edit the plane parts returned by conv3d
|
|
387
|
+
height_valid, width_valid = overlaps.shape
|
|
388
|
+
height_offset = (height - height_valid) // 2
|
|
389
|
+
width_offset = (width - width_valid) // 2
|
|
390
|
+
sub_volume = volume[
|
|
391
|
+
i + middle,
|
|
392
|
+
height_offset : height_offset + height_valid,
|
|
393
|
+
width_offset : width_offset + width_valid,
|
|
394
|
+
]
|
|
395
|
+
|
|
396
|
+
# do we use tile masks or just overlapping?
|
|
397
|
+
if inside_brain_tiles is not None:
|
|
398
|
+
# unfold tiles to cover the full voxel area each tile covers
|
|
399
|
+
inside = (
|
|
400
|
+
inside_brain_tiles[i + middle, :, :]
|
|
401
|
+
.repeat_interleave(tile_step_height, dim=0)
|
|
402
|
+
.repeat_interleave(tile_step_width, dim=1)
|
|
403
|
+
)
|
|
404
|
+
# again only process pixels in the valid area
|
|
405
|
+
inside = inside[
|
|
406
|
+
height_offset : height_offset + height_valid,
|
|
407
|
+
width_offset : width_offset + width_valid,
|
|
408
|
+
]
|
|
363
409
|
|
|
364
|
-
|
|
365
|
-
|
|
366
|
-
|
|
367
|
-
|
|
368
|
-
|
|
369
|
-
|
|
370
|
-
inside the brain or not. Tiles outside the brain are skipped.
|
|
371
|
-
input_volume :
|
|
372
|
-
3D array containing the plane-filtered data passed to the function
|
|
373
|
-
before walking. volume is edited in place, so this is the original
|
|
374
|
-
volume to prevent the changes for some cubes affective other cubes
|
|
375
|
-
during a single walk call.
|
|
376
|
-
volume :
|
|
377
|
-
3D array containing the plane-filtered data - edited in place.
|
|
378
|
-
kernel :
|
|
379
|
-
3D array
|
|
380
|
-
ball_radius :
|
|
381
|
-
Radius of the ball in the xy plane.
|
|
382
|
-
soma_centre_value :
|
|
383
|
-
Value that is used to mark pixels in *volume*.
|
|
384
|
-
|
|
385
|
-
Notes
|
|
386
|
-
-----
|
|
387
|
-
Warning: modifies volume in place!
|
|
388
|
-
"""
|
|
389
|
-
for y in prange(max_height):
|
|
390
|
-
for x in prange(max_width):
|
|
391
|
-
ball_centre_x = x + ball_radius
|
|
392
|
-
ball_centre_y = y + ball_radius
|
|
393
|
-
if _is_tile_to_check(
|
|
394
|
-
ball_centre_x,
|
|
395
|
-
ball_centre_y,
|
|
396
|
-
middle_z,
|
|
397
|
-
tile_step_width,
|
|
398
|
-
tile_step_height,
|
|
399
|
-
inside_brain_tiles,
|
|
400
|
-
):
|
|
401
|
-
if _cube_overlaps(
|
|
402
|
-
input_volume,
|
|
403
|
-
x,
|
|
404
|
-
x + kernel.shape[0],
|
|
405
|
-
y,
|
|
406
|
-
y + kernel.shape[1],
|
|
407
|
-
overlap_threshold,
|
|
408
|
-
threshold_value,
|
|
409
|
-
kernel,
|
|
410
|
-
):
|
|
411
|
-
volume[middle_z, ball_centre_x, ball_centre_y] = (
|
|
412
|
-
soma_centre_value
|
|
413
|
-
)
|
|
414
|
-
|
|
415
|
-
|
|
416
|
-
_walk_parallel = njit(parallel=True)(_walk_base)
|
|
417
|
-
_walk_single = njit(parallel=False)(_walk_base)
|
|
410
|
+
# must have enough ball overlap to be bright/tile is in brain
|
|
411
|
+
sub_volume[torch.logical_and(overlaps, inside)] = soma_centre_value
|
|
412
|
+
|
|
413
|
+
else:
|
|
414
|
+
# must have enough ball overlap to be bright
|
|
415
|
+
sub_volume[overlaps] = soma_centre_value
|