cellfinder 1.3.3__py3-none-any.whl → 1.4.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of cellfinder might be problematic. Click here for more details.
- cellfinder/core/classify/classify.py +3 -2
- cellfinder/core/detect/detect.py +118 -183
- cellfinder/core/detect/filters/plane/classical_filter.py +339 -37
- cellfinder/core/detect/filters/plane/plane_filter.py +137 -55
- cellfinder/core/detect/filters/plane/tile_walker.py +126 -60
- cellfinder/core/detect/filters/setup_filters.py +422 -65
- cellfinder/core/detect/filters/volume/ball_filter.py +313 -315
- cellfinder/core/detect/filters/volume/structure_detection.py +73 -35
- cellfinder/core/detect/filters/volume/structure_splitting.py +160 -96
- cellfinder/core/detect/filters/volume/volume_filter.py +444 -123
- cellfinder/core/main.py +6 -2
- cellfinder/core/tools/IO.py +45 -0
- cellfinder/core/tools/threading.py +380 -0
- cellfinder/core/tools/tools.py +128 -6
- {cellfinder-1.3.3.dist-info → cellfinder-1.4.0.dist-info}/METADATA +5 -4
- {cellfinder-1.3.3.dist-info → cellfinder-1.4.0.dist-info}/RECORD +20 -18
- {cellfinder-1.3.3.dist-info → cellfinder-1.4.0.dist-info}/WHEEL +1 -1
- {cellfinder-1.3.3.dist-info → cellfinder-1.4.0.dist-info}/LICENSE +0 -0
- {cellfinder-1.3.3.dist-info → cellfinder-1.4.0.dist-info}/entry_points.txt +0 -0
- {cellfinder-1.3.3.dist-info → cellfinder-1.4.0.dist-info}/top_level.txt +0 -0
|
@@ -1,45 +1,347 @@
|
|
|
1
1
|
import numpy as np
|
|
2
|
+
import torch
|
|
3
|
+
import torch.nn.functional as F
|
|
2
4
|
from scipy.ndimage import gaussian_filter, laplace
|
|
3
5
|
from scipy.signal import medfilt2d
|
|
4
6
|
|
|
5
7
|
|
|
6
|
-
|
|
7
|
-
|
|
8
|
-
|
|
8
|
+
@torch.jit.script
def normalize(
    filtered_planes: torch.Tensor,
    flip: bool,
    max_value: float = 1.0,
) -> None:
    """
    Rescales every z-plane of the 3d tensor, independently of the other
    planes, so that its values span the [0, max_value] range.

    When `flip` is `True` the sign of all the values is inverted first, so
    the most negative input value ends up as the maximum of the output.

    The tensor is modified inplace, nothing is returned.
    """
    # one row per z-plane. It's a view, so inplace ops on it modify
    # filtered_planes itself
    flat = filtered_planes.view(filtered_planes.shape[0], -1)

    if flip:
        flat.neg_()

    # shift each plane so its smallest value sits at zero
    lowest = flat.min(dim=1, keepdim=True)[0]
    flat.sub_(lowest)

    # the max is taken after the shift, so it is the plane's value range
    highest = flat.max(dim=1, keepdim=True)[0]
    # a constant plane is all zeros by now - divide it by 1 so it stays zero
    highest[highest == 0] = 1
    flat.div_(highest)

    if max_value != 1.0:
        # To leave room to label in the 3d detection.
        flat.mul_(max_value)
|
|
38
|
+
|
|
39
|
+
|
|
40
|
+
@torch.jit.script
def filter_for_peaks(
    planes: torch.Tensor,
    med_kernel: torch.Tensor,
    gauss_kernel: torch.Tensor,
    gauss_kernel_size: int,
    lap_kernel: torch.Tensor,
    device: str,
    clipping_value: float,
) -> torch.Tensor:
    """
    Takes the 3d z-stack and returns a new z-stack where the peaks are
    highlighted.

    It applies a median filter -> gaussian filter -> laplacian filter and
    then normalizes each plane to the [0, clipping_value] range.

    Parameters
    ----------
    planes : torch.Tensor
        The input z-stack with dimensions ZYX.
    med_kernel : torch.Tensor
        Conv2d weight that expands each pixel into a channel axis holding
        the pixels around it. The median is taken over that axis.
    gauss_kernel : torch.Tensor
        The 1D gaussian conv2d weight. Its shape is expected to be 11K1 when
        `device` is "cpu" and 111K otherwise, where K=filter size.
    gauss_kernel_size : int
        The size (K) of the gaussian kernel. Half of it, rounded down, is
        padded on each side of the Y and X axes before filtering.
    lap_kernel : torch.Tensor
        The 2D laplacian conv2d weight.
    device : str
        Selects the order in which the axes are filtered by the gaussian
        kernel. "cpu" uses the CPU order, any other value the CUDA order.
    clipping_value : float
        The maximum value of each returned plane after normalization.

    Returns
    -------
    filtered_planes : torch.Tensor
        The filtered z-stack with dimensions ZYX.
    """
    filtered_planes = planes.unsqueeze(1)  # ZYX -> ZCYX input, C=channels

    # ------------------ median filter ------------------
    # extracts patches to compute median over for each pixel
    # We go from ZCYX -> ZCYX, C=1 to C=9 with C containing the elements around
    # each Z,X,Y voxel over which we compute the median
    # Zero padding is ok here
    filtered_planes = F.conv2d(filtered_planes, med_kernel, padding="same")
    # we're going back to ZCYX=Z1YX by taking median of patches in C dim
    filtered_planes = filtered_planes.median(dim=1, keepdim=True)[0]

    # ------------------ gaussian filter ------------------
    # normalize the input data to 0-1 range. Otherwise, if the values are
    # large, we'd need a float64 so conv result is accurate
    normalize(filtered_planes, flip=False)

    # we need to do reflection padding around the tensor for parity with scipy
    # gaussian filtering. Scipy does reflection in a manner typically called
    # symmetric: (dcba|abcd|dcba). Torch does it like this: (dcb|abcd|cba). So
    # we manually do symmetric padding below
    pad = gauss_kernel_size // 2
    padding_mode = "reflect"
    # if data is too small for reflect, just replicate the border value
    if pad >= filtered_planes.shape[-1] or pad >= filtered_planes.shape[-2]:
        padding_mode = "replicate"
    filtered_planes = F.pad(filtered_planes, (pad,) * 4, padding_mode, 0.0)
    # We reflected torch style, so copy/shift everything by one to be symmetric
    # With a kernel of size 1 there is no padding and nothing to shift. It
    # must be skipped because `-pad:` would then select the whole axis, while
    # `-pad - 1 : -1` is empty, and the assignment would raise
    if pad > 0:
        filtered_planes[:, :, :pad, :] = filtered_planes[
            :, :, 1 : pad + 1, :
        ].clone()
        filtered_planes[:, :, -pad:, :] = filtered_planes[
            :, :, -pad - 1 : -1, :
        ].clone()
        filtered_planes[:, :, :, :pad] = filtered_planes[
            :, :, :, 1 : pad + 1
        ].clone()
        filtered_planes[:, :, :, -pad:] = filtered_planes[
            :, :, :, -pad - 1 : -1
        ].clone()

    # We apply the 1D gaussian filter twice, once for Y and once for X. The
    # filter shape passed in is 11K1 or 111K, depending on device. Where
    # K=filter size
    # see https://discuss.pytorch.org/t/performance-issue-for-conv2d-with-1d-
    # filter-along-a-dim/201734/2 for the reason for the moveaxis depending
    # on the device
    if device == "cpu":
        # kernel shape is 11K1. First do Y (second to last axis)
        filtered_planes = F.conv2d(
            filtered_planes, gauss_kernel, padding="valid"
        )
        # To do X, exchange X,Y axis, filter, change back. On CPU, Y (second
        # to last) axis is faster.
        filtered_planes = F.conv2d(
            filtered_planes.moveaxis(-1, -2), gauss_kernel, padding="valid"
        ).moveaxis(-1, -2)
    else:
        # kernel shape is 111K
        # First do Y (second to last axis). Exchange X,Y axis, filter, change
        # back. On CUDA, X (last) axis is faster.
        filtered_planes = F.conv2d(
            filtered_planes.moveaxis(-1, -2), gauss_kernel, padding="valid"
        ).moveaxis(-1, -2)
        # now do X, last axis
        filtered_planes = F.conv2d(
            filtered_planes, gauss_kernel, padding="valid"
        )

    # ------------------ laplacian filter ------------------
    # it's a 2d filter. Need to pad using symmetric for scipy parity. But,
    # torch doesn't have it, and we used a kernel of size 3, so for padding of
    # 1, replicate == symmetric. That's enough for parity with past scipy. If
    # we change kernel size in the future, we may have to do as above
    padding = lap_kernel.shape[-1] // 2
    filtered_planes = F.pad(filtered_planes, (padding,) * 4, "replicate")
    filtered_planes = F.conv2d(filtered_planes, lap_kernel, padding="valid")

    # we don't need the channel axis
    filtered_planes = filtered_planes[:, 0, :, :]

    # scale back to full scale, filtered values are negative so flip
    normalize(filtered_planes, flip=True, max_value=clipping_value)
    return filtered_planes
|
|
139
|
+
|
|
140
|
+
|
|
141
|
+
class PeakEnhancer:
    """
    A class that filters each plane in a z-stack such that peaks are
    visualized.

    It uses a series of 2D filters of median -> gaussian ->
    laplacian. Then normalizes each plane to be between [0, clipping_value].

    Parameters
    ----------
    torch_device: str
        The device on which the data and processing occurs on. Can be e.g.
        "cpu", "cuda" etc. Any data passed to the filter must be on this
        device. Returned data will also be on this device. The name is
        lower-cased before it is stored and used.
    dtype : torch.dtype
        The data-type of the input planes and the type to use internally.
        E.g. `torch.float32`.
    clipping_value : float
        The value such that after normalizing, the max value will be this
        clipping_value.
    laplace_gaussian_sigma : float
        Size of the sigma for the gaussian filter.
    use_scipy : bool
        If running on the CPU whether to use the scipy filters or the same
        pytorch filters used on CUDA. Scipy filters can be faster.
    """

    # binary kernel that generates square patches for each pixel so we can find
    # the median around the pixel
    med_kernel: torch.Tensor

    # gaussian 1D kernel with kernel/weight shape 11K1 or 111K, depending
    # on device. Where K=filter size
    gauss_kernel: torch.Tensor

    # 2D laplacian kernel with kernel/weight shape KxK. Where
    # K=filter size
    lap_kernel: torch.Tensor

    # the value such that after normalizing, the max value will be this
    # clipping_value
    clipping_value: float

    # sigma value for gaussian filter
    laplace_gaussian_sigma: float

    # the torch device to run on. E.g. cpu/cuda.
    torch_device: str

    # when running on CPU whether to use pytorch or scipy for filters
    use_scipy: bool

    median_filter_size: int = 3
    """
    The median filter size in x/y direction.

    **Must** be odd.
    """

    def __init__(
        self,
        torch_device: str,
        dtype: torch.dtype,
        clipping_value: float,
        laplace_gaussian_sigma: float,
        use_scipy: bool,
    ):
        super().__init__()
        self.torch_device = torch_device.lower()
        self.clipping_value = clipping_value
        self.laplace_gaussian_sigma = laplace_gaussian_sigma
        self.use_scipy = use_scipy

        # all these kernels are odd in size. They are built with the
        # normalized (lower-cased) device name, so the layout of the gaussian
        # kernel always matches the cpu/cuda branch that `enhance_peaks`
        # selects using the same `self.torch_device`
        self.med_kernel = self._get_median_kernel(self.torch_device, dtype)
        self.gauss_kernel = self._get_gaussian_kernel(
            self.torch_device, dtype, laplace_gaussian_sigma
        )
        self.lap_kernel = self._get_laplacian_kernel(self.torch_device, dtype)

    @property
    def gaussian_filter_size(self) -> int:
        """
        The gaussian filter 1d size, computed from `laplace_gaussian_sigma`.

        It is odd.
        """
        return 2 * int(round(4 * self.laplace_gaussian_sigma)) + 1

    def _get_median_kernel(
        self, torch_device: str, dtype: torch.dtype
    ) -> torch.Tensor:
        """
        Gets a median patch generator kernel, already on the correct
        device.

        Based on how kornia does it for median filtering.

        Raises
        ------
        ValueError
            If `median_filter_size` is even.
        """
        # must be odd kernel
        kernel_n = self.median_filter_size
        if not (kernel_n % 2):
            raise ValueError("The median filter size must be odd")

        # extract patches to compute median over for each pixel. When passing
        # input we go from ZCYX -> ZCYX, C=1 to C=9 and containing the elements
        # around each Z,X,Y over which we can then compute the median
        window_range = kernel_n * kernel_n  # e.g. 3x3
        kernel = torch.zeros(
            (window_range, window_range), device=torch_device, dtype=dtype
        )
        idx = torch.arange(window_range, device=torch_device)
        # diagonal of e.g. 9x9 is 1
        kernel[idx, idx] = 1.0
        # out channels, in channels, n*y, n*x. The kernel collects all the 3x3
        # elements around a pixel, using a binary mask for each element, as a
        # separate channel. So we go from 1 to 9 channels in the output
        kernel = kernel.view(window_range, 1, kernel_n, kernel_n)

        return kernel

    def _get_gaussian_kernel(
        self,
        torch_device: str,
        dtype: torch.dtype,
        laplace_gaussian_sigma: float,
    ) -> torch.Tensor:
        """
        Gets the 1D gaussian kernel used to filter the data.

        The kernel length is `gaussian_filter_size`, which is derived from
        `self.laplace_gaussian_sigma`, so that attribute must be set before
        this is called. The returned shape is 11K1 if `torch_device` is "cpu"
        and 111K otherwise, where K=filter size.
        """
        # we do 2 1D filters, once on each y, x dim.
        # shape of kernel will be 11K1 with dims Z, C, Y, X. C=1, Z is expanded
        # to number of z during filtering.
        kernel_size = self.gaussian_filter_size

        # to get the values of a 1D gaussian kernel, we pass a single impulse
        # data through the filter, which recovers the filter values. We do this
        # because scipy doesn't make their kernel available in public API and
        # we want parity with scipy filtering
        impulse = np.zeros(kernel_size)
        # the kernel size is odd, so this is the exact center element
        impulse[kernel_size // 2] = 1
        kernel = gaussian_filter(
            impulse, laplace_gaussian_sigma, mode="constant", cval=0
        )
        # kernel should be fully symmetric
        assert kernel[0] == kernel[-1]
        gauss_kernel = torch.from_numpy(kernel).type(dtype).to(torch_device)

        # default shape is (y, x) with y axis filtered only - we transpose
        # input to filter on x
        gauss_kernel = gauss_kernel.view(1, 1, -1, 1)

        # see https://discuss.pytorch.org/t/performance-issue-for-conv2d-
        # with-1d-filter-along-a-dim/201734. Conv2d is faster on a specific dim
        # for 1D filters depending on CPU/CUDA. See also filter_for_peaks
        # on CPU, we only do conv2d on the (1st) dim
        if torch_device != "cpu":
            # on CUDA, we only filter on the x dim, flipping input to filter y
            gauss_kernel = gauss_kernel.view(1, 1, 1, -1)

        return gauss_kernel

    def _get_laplacian_kernel(
        self, torch_device: str, dtype: torch.dtype
    ) -> torch.Tensor:
        """Gets a 2d laplacian kernel, based on scipy's laplace."""
        # for parity with scipy, scipy computes the laplacian with default
        # parameters and kernel size 3 using filter coefficients [1, -2, 1].
        # Each filtered pixel is the sum of the filter around the pixel
        # vertically and horizontally. We can do it in 2d at once with
        # coefficients below (faster than 2x1D for such small filter)
        return torch.as_tensor(
            [[0, 1, 0], [1, -4, 1], [0, 1, 0]],
            dtype=dtype,
            device=torch_device,
        ).view(1, 1, 3, 3)

    def enhance_peaks(self, planes: torch.Tensor) -> torch.Tensor:
        """
        Applies the filtering and normalization to the 3d z-stack (not inplace)
        and returns the filtered z-stack.

        The scipy filters are used, plane by plane, if the device is "cpu"
        and `use_scipy` is set. Otherwise the torch filters are used on the
        whole z-stack at once.
        """
        if self.torch_device == "cpu" and self.use_scipy:
            filtered_planes = planes.clone()
            for i in range(planes.shape[0]):
                img = planes[i, :, :].numpy()
                # use the same window as the torch median kernel
                img = medfilt2d(img, self.median_filter_size)
                img = gaussian_filter(img, self.laplace_gaussian_sigma)
                img = laplace(img)
                filtered_planes[i, :, :] = torch.from_numpy(img)

            # laplace makes values negative so flip
            normalize(
                filtered_planes,
                flip=True,
                max_value=self.clipping_value,
            )
            return filtered_planes

        filtered_planes = filter_for_peaks(
            planes,
            self.med_kernel,
            self.gauss_kernel,
            self.gaussian_filter_size,
            self.lap_kernel,
            self.torch_device,
            self.clipping_value,
        )
        return filtered_planes
|
|
@@ -1,87 +1,169 @@
|
|
|
1
|
-
from dataclasses import dataclass
|
|
2
|
-
from
|
|
3
|
-
from typing import Optional, Tuple
|
|
1
|
+
from dataclasses import dataclass, field
|
|
2
|
+
from typing import Tuple
|
|
4
3
|
|
|
5
|
-
import
|
|
6
|
-
import numpy as np
|
|
4
|
+
import torch
|
|
7
5
|
|
|
8
|
-
from cellfinder.core import
|
|
9
|
-
from cellfinder.core.detect.filters.plane.classical_filter import enhance_peaks
|
|
6
|
+
from cellfinder.core.detect.filters.plane.classical_filter import PeakEnhancer
|
|
10
7
|
from cellfinder.core.detect.filters.plane.tile_walker import TileWalker
|
|
11
8
|
|
|
12
9
|
|
|
13
10
|
@dataclass
class TileProcessor:
    """
    Highlights the peaks of every plane of a z-stack and, from the same
    planes, works out which tiles of each plane lie inside the brain.

    Every input plane goes through these steps:

    1. It is clipped to [0, clipping_value].
    2. It is tiled and the tiles are compared to the corner tile. A tile
       that `TileWalker` considers "bright" is marked as inside the brain.
    3. It is filtered:
        1. First with the peak enhancement filter (see `PeakEnhancer`).
        2. Then by thresholding. Voxels whose enhanced value is larger than
           (mean + stddev * n_sds_above_mean_thresh) are set to
           threshold_value.

    Parameters
    ----------
    plane_shape : tuple(int, int)
        Height/width of the planes.
    clipping_value : int
        Upper value that the input planes are clipped to. Result is scaled so
        max is this value.
    threshold_value : int
        Value used to mark bright features in the input planes after they have
        been run through the 2D filter.
    n_sds_above_mean_thresh : float
        Number of standard deviations above the mean threshold to use for
        determining whether a voxel is bright.
    log_sigma_size : float
        Size of the sigma for the gaussian filter. It is multiplied by
        `soma_diameter` to get the sigma given to `PeakEnhancer`.
    soma_diameter : int
        Diameter of the soma in voxels.
    torch_device: str
        The device on which the data and processing occurs on. Can be e.g.
        "cpu", "cuda" etc. Any data passed to the filter must be on this
        device. Returned data will also be on this device.
    dtype : str
        Name of the data-type of the input planes and the type to use
        internally, as an attribute of the `torch` module. E.g. "float32".
    use_scipy : bool
        If running on the CPU whether to use the scipy filters or the same
        pytorch filters used on CUDA. Scipy filters can be faster.
    """

    # ceiling that the input planes are clipped to before any processing
    clipping_value: int
    # marker value written into the planes where the enhanced plane is bright
    threshold_value: int
    # how many std above the mean an enhanced voxel must exceed to be marked
    n_sds_above_mean_thresh: float

    # the peak finding filter applied to the planes
    peak_enhancer: PeakEnhancer = field(init=False)
    # splits the planes into tiles and flags the bright (in-brain) ones
    tile_walker: TileWalker = field(init=False)

    def __init__(
        self,
        plane_shape: Tuple[int, int],
        clipping_value: int,
        threshold_value: int,
        n_sds_above_mean_thresh: float,
        log_sigma_size: float,
        soma_diameter: int,
        torch_device: str,
        dtype: str,
        use_scipy: bool,
    ):
        self.clipping_value = clipping_value
        self.threshold_value = threshold_value
        self.n_sds_above_mean_thresh = n_sds_above_mean_thresh

        # dtype arrives as a name, e.g. "float32", and is looked up on torch
        torch_dtype = getattr(torch, dtype)
        self.peak_enhancer = PeakEnhancer(
            torch_device=torch_device,
            dtype=torch_dtype,
            clipping_value=clipping_value,
            laplace_gaussian_sigma=log_sigma_size * soma_diameter,
            use_scipy=use_scipy,
        )

        self.tile_walker = TileWalker(
            plane_shape=plane_shape, soma_diameter=soma_diameter
        )

    def get_tile_mask(
        self, planes: torch.Tensor
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Runs the steps listed in the class description on the z-stack.

        Parameters
        ----------
        planes : torch.Tensor
            Input planes (z-stack). Note, the input data is modified.

        Returns
        -------
        planes : torch.Tensor
            Filtered and thresholded planes (z-stack). It is the same tensor
            that was passed in.
        inside_brain_tiles : torch.Tensor
            Boolean mask indicating which tiles are inside (1) or
            outside (0) the brain.
            It's a z-stack whose planes are the shape of the number of tiles
            in each planar axis.
        """
        planes.clip_(0, self.clipping_value)

        # the in-brain tiles are found on the clipped, unfiltered planes
        tiles_in_brain = self.tile_walker.get_bright_tiles(planes)

        # the enhanced copy decides which voxels of `planes` get marked
        peaks = self.peak_enhancer.enhance_peaks(planes)
        _threshold_planes(
            planes, peaks, self.n_sds_above_mean_thresh, self.threshold_value
        )

        return planes, tiles_in_brain

    def get_tiled_buffer(self, depth: int, device: str):
        """
        Forwards to `TileWalker.get_tiled_buffer` with the same arguments
        and returns its result.
        """
        return self.tile_walker.get_tiled_buffer(depth, device)
|
|
141
|
+
|
|
142
|
+
|
|
143
|
+
@torch.jit.script
def _threshold_planes(
    planes: torch.Tensor,
    enhanced_planes: torch.Tensor,
    n_sds_above_mean_thresh: float,
    threshold_value: int,
) -> None:
    """
    Marks, in-place, the voxels of each plane whose enhanced value is bright.

    A voxel of `planes` is set to threshold_value where the corresponding
    voxel of `enhanced_planes` is larger than that enhanced plane's
    mean + n_sds_above_mean_thresh*std. All the other voxels are set to zero.
    """
    # one row per z-plane, so the statistics are computed for each plane
    flat = enhanced_planes.view(enhanced_planes.shape[0], -1)

    mean = torch.mean(flat, dim=1, keepdim=True)
    spread = torch.std(flat, dim=1, keepdim=True)
    # re-add the last dim so the per-plane cutoff lines up with the z-stack
    cutoff = (mean + n_sds_above_mean_thresh * spread).unsqueeze(2)

    bright = enhanced_planes > cutoff
    planes.masked_fill_(bright, threshold_value)
    # Later steps only look for voxels of planes at or above the threshold
    # value. The marking is decided by *enhanced_planes*, yet an unmarked
    # voxel of *planes* may already hold the threshold value. Zeroing all
    # the unmarked voxels stops those from being counted
    planes.masked_fill_(torch.logical_not(bright), 0)
|