cellfinder 1.3.2__py3-none-any.whl → 1.4.0a0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -1,45 +1,347 @@
1
1
  import numpy as np
2
+ import torch
3
+ import torch.nn.functional as F
2
4
  from scipy.ndimage import gaussian_filter, laplace
3
5
  from scipy.signal import medfilt2d
4
6
 
5
7
 
6
- def enhance_peaks(
7
- img: np.ndarray, clipping_value: float, gaussian_sigma: float = 2.5
8
- ) -> np.ndarray:
8
@torch.jit.script
def normalize(
    filtered_planes: torch.Tensor,
    flip: bool,
    max_value: float = 1.0,
) -> None:
    """
    Scales every z-plane of the 3d stack, in place, so its values span
    [0, max_value]. Each plane is normalized independently of the others.

    When `flip` is `True`, all values are negated before any processing
    (used because the laplacian response to bright spots is negative).
    """
    # flat (Z, Y*X) view so each per-plane reduction is a single dim-1 op
    flat = filtered_planes.view(filtered_planes.shape[0], -1)

    if flip:
        flat.mul_(-1)

    # shift each plane so its minimum becomes zero
    flat.sub_(flat.min(dim=1, keepdim=True)[0])
    # the maximum must be computed after the shift
    peak = flat.max(dim=1, keepdim=True)[0]
    # a constant plane has peak == 0; dividing by 1 keeps it all-zero
    peak[peak == 0] = 1
    flat.div_(peak)

    if max_value != 1.0:
        # rescale to leave room to label in the 3d detection
        flat.mul_(max_value)
38
+
39
+
40
@torch.jit.script
def filter_for_peaks(
    planes: torch.Tensor,
    med_kernel: torch.Tensor,
    gauss_kernel: torch.Tensor,
    gauss_kernel_size: int,
    lap_kernel: torch.Tensor,
    device: str,
    clipping_value: float,
) -> torch.Tensor:
    """
    Takes the 3d z-stack and returns a new z-stack where the peaks are
    highlighted.

    It applies a median filter -> gaussian filter -> laplacian filter.

    Parameters
    ----------
    planes : torch.Tensor
        3d ZYX stack of input planes.
    med_kernel : torch.Tensor
        Conv kernel that expands each pixel's square neighborhood into
        channels so the median can be taken over the channel dim.
    gauss_kernel : torch.Tensor
        1D gaussian kernel, shaped 11K1 on CPU or 111K otherwise
        (K = gauss_kernel_size).
    gauss_kernel_size : int
        The (odd) length K of the 1D gaussian kernel; used to size padding.
    lap_kernel : torch.Tensor
        2D laplacian kernel (size 3 - see padding note below).
    device : str
        Selects the conv axis ordering; "cpu" uses the 11K1 path.
    clipping_value : float
        Maximum value of the normalized output planes.

    Returns
    -------
    torch.Tensor
        New ZYX stack, same shape as `planes`, scaled to
        [0, clipping_value].
    """
    filtered_planes = planes.unsqueeze(1)  # ZYX -> ZCYX input, C=channels

    # ------------------ median filter ------------------
    # extracts patches to compute median over for each pixel
    # We go from ZCYX -> ZCYX, C=1 to C=9 with C containing the elements around
    # each Z,X,Y voxel over which we compute the median
    # Zero padding is ok here
    filtered_planes = F.conv2d(filtered_planes, med_kernel, padding="same")
    # we're going back to ZCYX=Z1YX by taking median of patches in C dim
    filtered_planes = filtered_planes.median(dim=1, keepdim=True)[0]

    # ------------------ gaussian filter ------------------
    # normalize the input data to 0-1 range. Otherwise, if the values are
    # large, we'd need a float64 so conv result is accurate
    normalize(filtered_planes, flip=False)

    # we need to do reflection padding around the tensor for parity with scipy
    # gaussian filtering. Scipy does reflection in a manner typically called
    # symmetric: (dcba|abcd|dcba). Torch does it like this: (dcb|abcd|cba). So
    # we manually do symmetric padding below
    pad = gauss_kernel_size // 2
    padding_mode = "reflect"
    # if data is too small for reflect, just use constant border value
    if pad >= filtered_planes.shape[-1] or pad >= filtered_planes.shape[-2]:
        padding_mode = "replicate"
    filtered_planes = F.pad(filtered_planes, (pad,) * 4, padding_mode, 0.0)
    # We reflected torch style, so copy/shift everything by one to be symmetric
    # (.clone() needed because source and destination slices overlap)
    filtered_planes[:, :, :pad, :] = filtered_planes[
        :, :, 1 : pad + 1, :
    ].clone()
    filtered_planes[:, :, -pad:, :] = filtered_planes[
        :, :, -pad - 1 : -1, :
    ].clone()
    filtered_planes[:, :, :, :pad] = filtered_planes[
        :, :, :, 1 : pad + 1
    ].clone()
    filtered_planes[:, :, :, -pad:] = filtered_planes[
        :, :, :, -pad - 1 : -1
    ].clone()

    # We apply the 1D gaussian filter twice, once for Y and once for X. The
    # filter shape passed in is 11K1 or 111K, depending on device. Where
    # K=filter size
    # see https://discuss.pytorch.org/t/performance-issue-for-conv2d-with-1d-
    # filter-along-a-dim/201734/2 for the reason for the moveaxis depending
    # on the device
    if device == "cpu":
        # kernel shape is 11K1. First do Y (second to last axis)
        filtered_planes = F.conv2d(
            filtered_planes, gauss_kernel, padding="valid"
        )
        # To do X, exchange X,Y axis, filter, change back. On CPU, Y (second
        # to last) axis is faster.
        filtered_planes = F.conv2d(
            filtered_planes.moveaxis(-1, -2), gauss_kernel, padding="valid"
        ).moveaxis(-1, -2)
    else:
        # kernel shape is 111K
        # First do Y (second to last axis). Exchange X,Y axis, filter, change
        # back. On CUDA, X (last) axis is faster.
        filtered_planes = F.conv2d(
            filtered_planes.moveaxis(-1, -2), gauss_kernel, padding="valid"
        ).moveaxis(-1, -2)
        # now do X, last axis
        filtered_planes = F.conv2d(
            filtered_planes, gauss_kernel, padding="valid"
        )

    # ------------------ laplacian filter ------------------
    # it's a 2d filter. Need to pad using symmetric for scipy parity. But,
    # torch doesn't have it, and we used a kernel of size 3, so for padding of
    # 1, replicate == symmetric. That's enough for parity with past scipy. If
    # we change kernel size in the future, we may have to do as above
    padding = lap_kernel.shape[-1] // 2
    filtered_planes = F.pad(filtered_planes, (padding,) * 4, "replicate")
    filtered_planes = F.conv2d(filtered_planes, lap_kernel, padding="valid")

    # we don't need the channel axis
    filtered_planes = filtered_planes[:, 0, :, :]

    # scale back to full scale, filtered values are negative so flip
    normalize(filtered_planes, flip=True, max_value=clipping_value)
    return filtered_planes
139
+
140
+
141
class PeakEnhancer:
    """
    A class that filters each plane in a z-stack such that peaks are
    visualized.

    It uses a series of 2D filters of median -> gaussian ->
    laplacian. Then normalizes each plane to be between [0, clipping_value].

    Parameters
    ----------
    torch_device: str
        The device on which the data and processing occurs on. Can be e.g.
        "cpu", "cuda" etc. Any data passed to the filter must be on this
        device. Returned data will also be on this device.
    dtype : torch.dtype
        The data-type of the input planes and the type to use internally.
        E.g. `torch.float32`.
    clipping_value : int
        The value such that after normalizing, the max value will be this
        clipping_value.
    laplace_gaussian_sigma : float
        Size of the sigma for the gaussian filter.
    use_scipy : bool
        If running on the CPU whether to use the scipy filters or the same
        pytorch filters used on CUDA. Scipy filters can be faster.
    """

    # binary kernel that generates square patches for each pixel so we can find
    # the median around the pixel
    med_kernel: torch.Tensor

    # gaussian 1D kernel with kernel/weight shape 11K1 or 111K, depending
    # on device. Where K=filter size
    gauss_kernel: torch.Tensor

    # 2D laplacian kernel with kernel/weight shape KxK. Where
    # K=filter size
    lap_kernel: torch.Tensor

    # the value such that after normalizing, the max value will be this
    # clipping_value
    clipping_value: float

    # sigma value for gaussian filter
    laplace_gaussian_sigma: float

    # the torch device to run on. E.g. cpu/cuda.
    torch_device: str

    # when running on CPU whether to use pytorch or scipy for filters
    use_scipy: bool

    median_filter_size: int = 3
    """
    The median filter size in x/y direction.

    **Must** be odd.
    """

    def __init__(
        self,
        torch_device: str,
        dtype: torch.dtype,
        clipping_value: float,
        laplace_gaussian_sigma: float,
        use_scipy: bool,
    ):
        super().__init__()
        self.torch_device = torch_device.lower()
        self.clipping_value = clipping_value
        self.laplace_gaussian_sigma = laplace_gaussian_sigma
        self.use_scipy = use_scipy

        # NOTE(review): `self.torch_device` is lower-cased, but the kernel
        # builders below receive the original-case string and
        # `_get_gaussian_kernel` compares it against exactly "cpu" - callers
        # should pass lowercase device names. TODO confirm.
        # all these kernels are odd in size
        self.med_kernel = self._get_median_kernel(torch_device, dtype)
        self.gauss_kernel = self._get_gaussian_kernel(
            torch_device, dtype, laplace_gaussian_sigma
        )
        self.lap_kernel = self._get_laplacian_kernel(torch_device, dtype)

    @property
    def gaussian_filter_size(self) -> int:
        """
        The gaussian filter 1d size.

        It is odd (2 * round(4 * sigma) + 1).
        """
        return 2 * int(round(4 * self.laplace_gaussian_sigma)) + 1

    def _get_median_kernel(
        self, torch_device: str, dtype: torch.dtype
    ) -> torch.Tensor:
        """
        Gets a median patch generator kernel, already on the correct
        device.

        Based on how kornia does it for median filtering.
        """
        # must be odd kernel
        kernel_n = self.median_filter_size
        if not (kernel_n % 2):
            raise ValueError("The median filter size must be odd")

        # extract patches to compute median over for each pixel. When passing
        # input we go from ZCYX -> ZCYX, C=1 to C=9 and containing the elements
        # around each Z,X,Y over which we can then compute the median
        window_range = kernel_n * kernel_n  # e.g. 3x3
        kernel = torch.zeros(
            (window_range, window_range), device=torch_device, dtype=dtype
        )
        idx = torch.arange(window_range, device=torch_device)
        # diagonal of e.g. 9x9 is 1
        kernel[idx, idx] = 1.0
        # out channels, in channels, n*y, n*x. The kernel collects all the 3x3
        # elements around a pixel, using a binary mask for each element, as a
        # separate channel. So we go from 1 to 9 channels in the output
        kernel = kernel.view(window_range, 1, kernel_n, kernel_n)

        return kernel

    def _get_gaussian_kernel(
        self,
        torch_device: str,
        dtype: torch.dtype,
        laplace_gaussian_sigma: float,
    ) -> torch.Tensor:
        """Gets the 1D gaussian kernel used to filter the data."""
        # we do 2 1D filters, once on each y, x dim.
        # shape of kernel will be 11K1 with dims Z, C, Y, X. C=1, Z is expanded
        # to number of z during filtering.
        kernel_size = self.gaussian_filter_size

        # to get the values of a 1D gaussian kernel, we pass a single impulse
        # data through the filter, which recovers the filter values. We do this
        # because scipy doesn't make their kernel available in public API and
        # we want parity with scipy filtering
        impulse = np.zeros(kernel_size)
        # place the unit impulse at the center index of the odd-length array
        impulse[kernel_size // 2] = 1
        kernel = gaussian_filter(
            impulse, laplace_gaussian_sigma, mode="constant", cval=0
        )
        # kernel should be fully symmetric
        assert kernel[0] == kernel[-1]
        gauss_kernel = torch.from_numpy(kernel).type(dtype).to(torch_device)

        # default shape is (y, x) with y axis filtered only - we transpose
        # input to filter on x
        gauss_kernel = gauss_kernel.view(1, 1, -1, 1)

        # see https://discuss.pytorch.org/t/performance-issue-for-conv2d-
        # with-1d-filter-along-a-dim/201734. Conv2d is faster on a specific dim
        # for 1D filters depending on CPU/CUDA. See also filter_for_peaks
        # on CPU, we only do conv2d on the (1st) dim
        if torch_device != "cpu":
            # on CUDA, we only filter on the x dim, flipping input to filter y
            gauss_kernel = gauss_kernel.view(1, 1, 1, -1)

        return gauss_kernel

    def _get_laplacian_kernel(
        self, torch_device: str, dtype: torch.dtype
    ) -> torch.Tensor:
        """Gets a 2d laplacian kernel, based on scipy's laplace."""
        # for parity with scipy, scipy computes the laplacian with default
        # parameters and kernel size 3 using filter coefficients [1, -2, 1].
        # Each filtered pixel is the sum of the filter around the pixel
        # vertically and horizontally. We can do it in 2d at once with
        # coefficients below (faster than 2x1D for such small filter)
        return torch.as_tensor(
            [[0, 1, 0], [1, -4, 1], [0, 1, 0]],
            dtype=dtype,
            device=torch_device,
        ).view(1, 1, 3, 3)

    def enhance_peaks(self, planes: torch.Tensor) -> torch.Tensor:
        """
        Applies the filtering and normalization to the 3d z-stack (not inplace)
        and returns the filtered z-stack.
        """
        if self.torch_device == "cpu" and self.use_scipy:
            # scipy path: filter each plane separately on the CPU, writing
            # results into a copy so the input stays untouched
            filtered_planes = planes.clone()
            for i in range(planes.shape[0]):
                img = planes[i, :, :].numpy()
                img = medfilt2d(img)
                img = gaussian_filter(img, self.laplace_gaussian_sigma)
                img = laplace(img)
                filtered_planes[i, :, :] = torch.from_numpy(img)

            # laplace makes values negative so flip
            normalize(
                filtered_planes,
                flip=True,
                max_value=self.clipping_value,
            )
            return filtered_planes

        # torch path: one fused median -> gaussian -> laplacian pipeline
        filtered_planes = filter_for_peaks(
            planes,
            self.med_kernel,
            self.gauss_kernel,
            self.gaussian_filter_size,
            self.lap_kernel,
            self.torch_device,
            self.clipping_value,
        )
        return filtered_planes
@@ -1,87 +1,169 @@
1
- from dataclasses import dataclass
2
- from threading import Lock
3
- from typing import Optional, Tuple
1
+ from dataclasses import dataclass, field
2
+ from typing import Tuple
4
3
 
5
- import dask.array as da
6
- import numpy as np
4
+ import torch
7
5
 
8
- from cellfinder.core import types
9
- from cellfinder.core.detect.filters.plane.classical_filter import enhance_peaks
6
+ from cellfinder.core.detect.filters.plane.classical_filter import PeakEnhancer
10
7
  from cellfinder.core.detect.filters.plane.tile_walker import TileWalker
11
8
 
12
9
 
13
10
@dataclass
class TileProcessor:
    """
    Processor that filters each plane to highlight the peaks and also
    tiles and thresholds each plane returning a mask indicating which
    tiles are inside the brain.

    Each input plane is:

    1. Clipped to [0, clipping_value].
    2. Tiled and compared to the corner tile. Any tile that is "bright"
       according to `TileWalker` is marked as being in the brain.
    3. Filtered
       1. Run through the peak enhancement filter (see `PeakEnhancer`)
       2. Thresholded. Any values that are larger than
          (mean + stddev * n_sds_above_mean_thresh) are set to
          threshold_value.

    Parameters
    ----------
    plane_shape : tuple(int, int)
        Height/width of the planes.
    clipping_value : int
        Upper value that the input planes are clipped to. Result is scaled so
        max is this value.
    threshold_value : int
        Value used to mark bright features in the input planes after they have
        been run through the 2D filter.
    n_sds_above_mean_thresh : float
        Number of standard deviations above the mean threshold to use for
        determining whether a voxel is bright.
    log_sigma_size : float
        Size of the sigma for the gaussian filter.
    soma_diameter : float
        Diameter of the soma in voxels.
    torch_device: str
        The device on which the data and processing occurs on. Can be e.g.
        "cpu", "cuda" etc. Any data passed to the filter must be on this
        device. Returned data will also be on this device.
    dtype : str
        The data-type of the input planes and the type to use internally.
        E.g. "float32".
    use_scipy : bool
        If running on the CPU whether to use the scipy filters or the same
        pytorch filters used on CUDA. Scipy filters can be faster.
    """

    # Upper value that the input plane is clipped to. Result is scaled so
    # max is this value
    clipping_value: int
    # Value used to mark bright features in the input planes after they have
    # been run through the 2D filter
    threshold_value: int
    # voxels who are this many std above mean or more are set to
    # threshold_value
    n_sds_above_mean_thresh: float

    # filter that finds the peaks in the planes
    peak_enhancer: PeakEnhancer = field(init=False)
    # generates tiles of the planes, with each tile marked as being inside
    # or outside the brain based on brightness
    tile_walker: TileWalker = field(init=False)

    # NOTE: @dataclass does not generate __init__ because one is defined
    # explicitly here; the field(init=False) attributes are set in it
    def __init__(
        self,
        plane_shape: Tuple[int, int],
        clipping_value: int,
        threshold_value: int,
        n_sds_above_mean_thresh: float,
        log_sigma_size: float,
        soma_diameter: int,
        torch_device: str,
        dtype: str,
        use_scipy: bool,
    ):
        self.clipping_value = clipping_value
        self.threshold_value = threshold_value
        self.n_sds_above_mean_thresh = n_sds_above_mean_thresh

        # the gaussian sigma scales with the expected soma size
        laplace_gaussian_sigma = log_sigma_size * soma_diameter
        self.peak_enhancer = PeakEnhancer(
            torch_device=torch_device,
            dtype=getattr(torch, dtype),
            clipping_value=self.clipping_value,
            laplace_gaussian_sigma=laplace_gaussian_sigma,
            use_scipy=use_scipy,
        )

        self.tile_walker = TileWalker(
            plane_shape=plane_shape,
            soma_diameter=soma_diameter,
        )

    def get_tile_mask(
        self, planes: torch.Tensor
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Applies the filtering listed in the class description.

        Parameters
        ----------
        planes : torch.Tensor
            Input planes (z-stack). Note, the input data is modified.

        Returns
        -------
        planes : torch.Tensor
            Filtered and thresholded planes (z-stack). This is the same
            tensor object as the input, modified in place.
        inside_brain_tiles : torch.Tensor
            Boolean mask indicating which tiles are inside (1) or
            outside (0) the brain.
            It's a z-stack whose planes are the shape of the number of tiles
            in each planar axis.
        """
        # in-place clip of the input planes
        torch.clip_(planes, 0, self.clipping_value)
        # Get tiles that are within the brain
        inside_brain_tiles = self.tile_walker.get_bright_tiles(planes)
        # Threshold the image
        enhanced_planes = self.peak_enhancer.enhance_peaks(planes)

        # marks bright voxels in `planes` (in place) based on the enhanced
        # planes' per-plane statistics
        _threshold_planes(
            planes,
            enhanced_planes,
            self.n_sds_above_mean_thresh,
            self.threshold_value,
        )

        return planes, inside_brain_tiles

    def get_tiled_buffer(self, depth: int, device: str):
        """Delegates to the tile walker (see `TileWalker.get_tiled_buffer`)."""
        return self.tile_walker.get_tiled_buffer(depth, device)
141
+
142
+
143
@torch.jit.script
def _threshold_planes(
    planes: torch.Tensor,
    enhanced_planes: torch.Tensor,
    n_sds_above_mean_thresh: float,
    threshold_value: int,
) -> None:
    """
    Brightness-thresholds `planes` in place, judged on `enhanced_planes`.

    A voxel of `planes` is set to `threshold_value` when the corresponding
    voxel of `enhanced_planes` exceeds that plane's
    mean + n_sds_above_mean_thresh * std; every other voxel is set to zero.
    """
    num_z = enhanced_planes.shape[0]
    flat = enhanced_planes.view(num_z, -1)

    # per-plane statistics, reshaped to (Z, 1, 1) so they broadcast over Y, X
    mean = torch.mean(flat, dim=1, keepdim=True).unsqueeze(2)
    std = torch.std(flat, dim=1, keepdim=True).unsqueeze(2)
    cutoff = mean + n_sds_above_mean_thresh * std

    bright = enhanced_planes > cutoff
    # Subsequent steps only care about voxels at/above threshold in `planes`,
    # but brightness is judged on `enhanced_planes`. A voxel of `planes` may
    # already sit at threshold_value without being bright here, so all
    # non-bright voxels are zeroed to stop them counting downstream.
    planes.masked_fill_(torch.logical_not(bright), 0)
    planes.masked_fill_(bright, threshold_value)