cellfinder 1.3.3__py3-none-any.whl → 1.4.0a0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -1,29 +1,35 @@
1
+ import math
1
2
  from functools import lru_cache
3
+ from typing import Optional
2
4
 
3
5
  import numpy as np
4
- from numba import njit, objmode, prange
5
- from numba.core import types
6
- from numba.experimental import jitclass
6
+ import torch
7
+ import torch.nn.functional as F
7
8
 
8
9
  from cellfinder.core.tools.array_operations import bin_mean_3d
9
10
  from cellfinder.core.tools.geometry import make_sphere
10
11
 
11
- DEBUG = False
12
12
 
13
- uint32_3d_type = types.uint32[:, :, :]
14
- bool_3d_type = types.bool_[:, :, :]
15
- float_3d_type = types.float64[:, :, :]
13
+ class InvalidVolume(ValueError):
14
+ """
15
+ Raised when the volume passed to BallFilter is too small or does not meet
16
+ requirements.
17
+ """
18
+
19
+ pass
16
20
 
17
21
 
18
22
  @lru_cache(maxsize=50)
19
23
  def get_kernel(ball_xy_size: int, ball_z_size: int) -> np.ndarray:
20
- # Create a spherical kernel.
21
- #
22
- # This is done by:
23
- # 1. Generating a binary sphere at a resolution *upscale_factor* larger
24
- # than desired.
25
- # 2. Downscaling the binary sphere to get a 'fuzzy' sphere at the
26
- # original intended scale
24
+ """
25
+ Create a spherical kernel.
26
+
27
+ This is done by:
28
+ 1. Generating a binary sphere at a resolution *upscale_factor* larger
29
+ than desired.
30
+ 2. Downscaling the binary sphere to get a 'fuzzy' sphere at the
31
+ original intended scale
32
+ """
27
33
  upscale_factor: int = 7
28
34
  upscaled_kernel_shape = (
29
35
  upscale_factor * ball_xy_size,
@@ -42,11 +48,11 @@ def get_kernel(ball_xy_size: int, ball_z_size: int) -> np.ndarray:
42
48
  upscaled_ball_radius,
43
49
  upscaled_ball_centre_position,
44
50
  )
45
- sphere_kernel = sphere_kernel.astype(np.float64)
51
+ sphere_kernel = sphere_kernel.astype(np.float32)
46
52
  kernel = bin_mean_3d(
47
53
  sphere_kernel,
48
- bin_height=upscale_factor,
49
54
  bin_width=upscale_factor,
55
+ bin_height=upscale_factor,
50
56
  bin_depth=upscale_factor,
51
57
  )
52
58
 
@@ -59,359 +65,351 @@ def get_kernel(ball_xy_size: int, ball_z_size: int) -> np.ndarray:
59
65
  return kernel
60
66
 
61
67
 
62
- # volume indices/size is 64 bit for very large brains(!)
63
- spec = [
64
- ("ball_xy_size", types.uint32),
65
- ("ball_z_size", types.uint32),
66
- ("tile_step_width", types.uint64),
67
- ("tile_step_height", types.uint64),
68
- ("THRESHOLD_VALUE", types.uint32),
69
- ("SOMA_CENTRE_VALUE", types.uint32),
70
- ("overlap_fraction", types.float64),
71
- ("overlap_threshold", types.float64),
72
- ("middle_z_idx", types.uint32),
73
- ("_num_z_added", types.uint32),
74
- ("kernel", float_3d_type),
75
- ("volume", uint32_3d_type),
76
- ("inside_brain_tiles", bool_3d_type),
77
- ]
78
-
79
-
80
- @jitclass(spec=spec)
81
68
  class BallFilter:
82
69
  """
83
70
  A 3D ball filter.
84
71
 
85
- This runs a spherical kernel across the (x, y) dimensions
72
+ This runs a spherical kernel across the 2d planar dimensions
86
73
  of a *ball_z_size* stack of planes, and marks pixels in the middle
87
- plane of the stack that have a high enough intensity within the
88
- spherical kernel.
74
+ plane of the stack that have a high enough intensity over the
75
+ spherical kernel.
76
+
77
+ Parameters
78
+ ----------
79
+ plane_height, plane_width : int
80
+ Height/width of the planes.
81
+ ball_xy_size : int
82
+ Diameter of the spherical kernel in the x/y dimensions.
83
+ ball_z_size : int
84
+ Diameter of the spherical kernel in the z dimension.
85
+ Equal to the number of planes stacked to filter
86
+ the central plane of the stack.
87
+ overlap_fraction : float
88
+ The fraction of pixels within the spherical kernel that
89
+ have to be over *threshold_value* for a pixel to be marked
90
+ as having a high intensity.
91
+ threshold_value : int
92
+ Value above which an individual pixel is considered to have
93
+ a high intensity.
94
+ soma_centre_value : int
95
+ Value used to mark pixels with a high enough intensity.
96
+ tile_height, tile_width : int
97
+ Height/width of individual tiles in the mask generated by
98
+ 2D filtering.
99
+ dtype : str
100
+ The data-type of the input planes and the type to use internally.
101
+ E.g. "float32".
102
+ batch_size: int
103
+ The number of planes that will typically be passed in a single batch to
104
+ `append`. This is only used to calculate `num_batches_before_ready`.
105
+ Defaults to 1.
106
+ torch_device: str
107
+ The device on which the data and processing occur. Can be e.g.
108
+ "cpu", "cuda" etc. Defaults to "cpu". Any data passed to the filter
109
+ must be on this device. Returned data will also be on this device.
110
+ use_mask : bool
111
+ Whether tiling masks will be used in `append`. If False, tile masks
112
+ won't be passed in and/or will be ignored. Defaults to True.
113
+ """
114
+
115
+ num_batches_before_ready: int
89
116
  """
117
+ The number of batches of size `batch_size` passed to `append`
118
+ before `ready` would return True.
119
+ """
120
+
121
+ # the inside brain tiled mask, if tiles are used (use_mask is True)
122
+ inside_brain_tiles: Optional[torch.Tensor] = None
90
123
 
91
124
  def __init__(
92
125
  self,
93
- plane_width: int,
94
126
  plane_height: int,
127
+ plane_width: int,
95
128
  ball_xy_size: int,
96
129
  ball_z_size: int,
97
130
  overlap_fraction: float,
98
- tile_step_width: int,
99
- tile_step_height: int,
100
131
  threshold_value: int,
101
132
  soma_centre_value: int,
133
+ tile_height: int,
134
+ tile_width: int,
135
+ dtype: str,
136
+ batch_size: int = 1,
137
+ torch_device: str = "cpu",
138
+ use_mask: bool = True,
102
139
  ):
103
- """
104
- Parameters
105
- ----------
106
- plane_width, plane_height :
107
- Width/height of the planes.
108
- ball_xy_size :
109
- Diameter of the spherical kernel in the x/y dimensions.
110
- ball_z_size :
111
- Diameter of the spherical kernel in the z dimension.
112
- Equal to the number of planes that stacked to filter
113
- the central plane of the stack.
114
- overlap_fraction :
115
- The fraction of pixels within the spherical kernel that
116
- have to be over *threshold_value* for a pixel to be marked
117
- as having a high intensity.
118
- tile_step_width, tile_step_height :
119
- Width/height of individual tiles in the mask generated by
120
- 2D filtering.
121
- threshold_value :
122
- Value above which an individual pixel is considered to have
123
- a high intensity.
124
- soma_centre_value :
125
- Value used to mark pixels with a high enough intensity.
126
- """
127
140
  self.ball_xy_size = ball_xy_size
128
141
  self.ball_z_size = ball_z_size
129
142
  self.overlap_fraction = overlap_fraction
130
- self.tile_step_width = tile_step_width
131
- self.tile_step_height = tile_step_height
143
+ self.tile_step_height = tile_height
144
+ self.tile_step_width = tile_width
145
+
146
+ d1 = plane_height
147
+ d2 = plane_width
148
+ ball_xy_size = self.ball_xy_size
149
+ if d1 < ball_xy_size or d2 < ball_xy_size:
150
+ raise InvalidVolume(
151
+ f"Invalid plane size {d1}x{d2}. Needs to be at least "
152
+ f"{ball_xy_size} in each dimension"
153
+ )
132
154
 
133
155
  self.THRESHOLD_VALUE = threshold_value
134
156
  self.SOMA_CENTRE_VALUE = soma_centre_value
135
157
 
136
- # getting kernel is not jitted
137
- with objmode(kernel=float_3d_type):
138
- kernel = get_kernel(ball_xy_size, ball_z_size)
139
- self.kernel = kernel
140
-
141
- self.overlap_threshold = np.sum(self.overlap_fraction * self.kernel)
158
+ # kernel comes in as XYZ, change to ZYX so it aligns with data
159
+ kernel = np.moveaxis(get_kernel(ball_xy_size, self.ball_z_size), 2, 0)
160
+ self.overlap_threshold = np.sum(self.overlap_fraction * kernel)
161
+ self.kernel_xy_size = kernel.shape[-2:]
162
+ self.kernel_z_size = self.ball_z_size
163
+
164
+ # convert to right type and pin for faster copying
165
+ kernel = torch.from_numpy(kernel).type(getattr(torch, dtype))
166
+ if torch_device != "cpu":
167
+ # torch at one point threw a cuda memory error when splitting cells
168
+ # on cpu due to pinning. It's best to only pin when using cuda
169
+ kernel.pin_memory()
170
+ # add 2 dimensions at the start so we have 11ZYX. We need this shape in
171
+ # the conv step
172
+ self.kernel = (
173
+ kernel.unsqueeze(0)
174
+ .unsqueeze(0)
175
+ .to(device=torch_device, non_blocking=True)
176
+ )
142
177
 
143
- # Stores the current planes that are being filtered
144
- # first axis is z for faster rotating the z-axis
145
- self.volume = np.empty(
146
- (ball_z_size, plane_width, plane_height),
147
- dtype=np.uint32,
178
+ self.num_batches_before_ready = int(
179
+ math.ceil(self.ball_z_size / batch_size)
180
+ )
181
+ # Stores the current planes that are being filtered. Start with no data
182
+ self.volume = torch.empty(
183
+ (0, plane_height, plane_width),
184
+ dtype=getattr(torch, dtype),
148
185
  )
149
186
  # Index of the middle plane in the volume
150
- self.middle_z_idx = int(np.floor(ball_z_size / 2))
151
- self._num_z_added = 0
187
+ self.middle_z_idx = int(np.floor(self.ball_z_size / 2))
152
188
 
189
+ if not use_mask:
190
+ return
153
191
  # first axis is z
154
- self.inside_brain_tiles = np.empty(
192
+ n_vertical_tiles = int(np.ceil(plane_height / self.tile_step_height))
193
+ n_horizontal_tiles = int(np.ceil(plane_width / self.tile_step_width))
194
+ # Stores tile masks. We start with no data
195
+ self.inside_brain_tiles = torch.empty(
155
196
  (
156
- ball_z_size,
157
- int(np.ceil(plane_width / tile_step_width)),
158
- int(np.ceil(plane_height / tile_step_height)),
197
+ 0,
198
+ n_vertical_tiles,
199
+ n_horizontal_tiles,
159
200
  ),
160
- dtype=np.bool_,
201
+ dtype=torch.bool,
161
202
  )
162
203
 
163
204
  @property
164
- def ready(self) -> bool:
205
+ def first_valid_plane(self) -> int:
165
206
  """
166
- Return `True` if enough planes have been appended to run the filter.
207
+ The index in `self.volume` (or the planes passed in) that will be the
208
+ first plane returned from `get_processed_planes`.
209
+
210
+ E.g. if `ball_z_size` is 3, then this may return 1. Meaning the second
211
+ plane passed to `append` (index 1), will be the first returned plane
212
+ by `get_processed_planes`.
167
213
  """
168
- return self._num_z_added >= self.ball_z_size
214
+ return int(math.floor(self.ball_z_size / 2))
169
215
 
170
- def append(self, plane: np.ndarray, mask: np.ndarray) -> None:
216
+ @property
217
+ def remaining_planes(self) -> int:
171
218
  """
172
- Add a new 2D plane to the filter.
219
+ The number of planes in `self.volume` (or the planes passed in) that
220
+ will remain unprocessed after all the planes have been `walk`ed
221
+ and `get_processed_planes` called.
222
+
223
+ E.g. if `ball_z_size` is 3, then this may return 1. Meaning the last
224
+ plane passed to `append`, will never be returned by
225
+ `get_processed_planes` because the filter "center" never overlapped
226
+ with it.
173
227
  """
174
- if DEBUG:
175
- assert [e for e in plane.shape[:2]] == [
176
- e for e in self.volume.shape[1:]
177
- ], 'plane shape mismatch, expected "{}", got "{}"'.format(
178
- [e for e in self.volume.shape[1:]],
179
- [e for e in plane.shape[:2]],
180
- )
181
- assert [e for e in mask.shape[:2]] == [
182
- e for e in self.inside_brain_tiles.shape[1:]
183
- ], 'mask shape mismatch, expected"{}", got {}"'.format(
184
- [e for e in self.inside_brain_tiles.shape[1:]],
185
- [e for e in mask.shape[:2]],
186
- )
228
+ return self.ball_z_size - self.first_valid_plane - 1
187
229
 
188
- if self.ready:
189
- # Shift everything down by one to make way for the new plane
190
- # this is faster than np.roll, especially with z-axis first
191
- self.volume[:-1, :, :] = self.volume[1:, :, :]
192
- self.inside_brain_tiles[:-1, :, :] = self.inside_brain_tiles[
193
- 1:, :, :
194
- ]
230
+ @property
231
+ def ready(self) -> bool:
232
+ """
233
+ Return whether enough planes have been appended to run the filter
234
+ using `walk`.
235
+ """
236
+ return self.volume.shape[0] >= self.kernel_z_size
237
+
238
+ def append(
239
+ self, planes: torch.Tensor, masks: Optional[torch.Tensor] = None
240
+ ) -> None:
241
+ """
242
+ Add a new z-stack to the filter.
195
243
 
196
- # index for *next* slice is num we added *so far* until max
197
- idx = min(self._num_z_added, self.ball_z_size - 1)
198
- self._num_z_added += 1
244
+ Previous stacks passed to `append` are removed, except enough planes
245
+ at the top of the previous z-stack to provide padding so we can filter
246
+ starting from the first plane in `planes`. The first time we call
247
+ `append`, `first_valid_plane` is the first plane to actually be
248
+ filtered in the z-stack due to lack of padding.
199
249
 
200
- # Add the new plane to the top of volume and inside_brain_tiles
201
- self.volume[idx, :, :] = plane
202
- self.inside_brain_tiles[idx, :, :] = mask
250
+ So make sure to call `walk`/`get_processed_planes` before calling
251
+ `append` again.
203
252
 
204
- def get_middle_plane(self) -> np.ndarray:
205
- """
206
- Get the plane in the middle of self.volume.
253
+ Parameters
254
+ ----------
255
+ planes : torch.Tensor
256
+ The z-stack. There can be one or more planes in the stack, but it
257
+ must have 3 dimensions. Input data is not modified.
258
+ masks : torch.Tensor
259
+ A z-stack tile mask, indicating for each tile whether it's in or
260
+ outside the brain. If the latter it's excluded.
261
+
262
+ If `use_mask` was True, this must be provided. If False, this
263
+ parameter will be ignored.
264
+
265
+ Input data is not modified.
207
266
  """
208
- return self.volume[self.middle_z_idx, :, :].copy()
209
-
210
- def walk(self, parallel: bool = False) -> None:
211
- # **don't** pass parallel as keyword arg - numba struggles with it
212
- # Highly optimised because most time critical
213
- ball_radius = self.ball_xy_size // 2
214
- # Get extents of image that are covered by tiles
215
- tile_mask_covered_img_width = (
216
- self.inside_brain_tiles.shape[1] * self.tile_step_width
217
- )
218
- tile_mask_covered_img_height = (
219
- self.inside_brain_tiles.shape[2] * self.tile_step_height
220
- )
221
- # Get maximum offsets for the ball
222
- max_width = tile_mask_covered_img_width - self.ball_xy_size
223
- max_height = tile_mask_covered_img_height - self.ball_xy_size
224
-
225
- # we have to pass the raw volume so walk doesn't use its edits as it
226
- # processes the volume. self.volume is the one edited in place
227
- input_volume = self.volume.copy()
228
-
229
- if parallel:
230
- _walk_parallel(
231
- max_height,
232
- max_width,
233
- self.tile_step_width,
234
- self.tile_step_height,
235
- self.inside_brain_tiles,
236
- input_volume,
237
- self.volume,
238
- self.kernel,
239
- ball_radius,
240
- self.middle_z_idx,
241
- self.overlap_threshold,
242
- self.THRESHOLD_VALUE,
243
- self.SOMA_CENTRE_VALUE,
267
+ if self.volume.shape[0]:
268
+ if self.volume.shape[0] < self.kernel_z_size:
269
+ num_remaining_with_padding = 0
270
+ else:
271
+ num_remaining = self.kernel_z_size - (self.middle_z_idx + 1)
272
+ num_remaining_with_padding = num_remaining + self.middle_z_idx
273
+
274
+ self.volume = torch.cat(
275
+ [self.volume[-num_remaining_with_padding:, :, :], planes],
276
+ dim=0,
244
277
  )
278
+
279
+ if self.inside_brain_tiles is not None:
280
+ self.inside_brain_tiles = torch.cat(
281
+ [
282
+ self.inside_brain_tiles[
283
+ -num_remaining_with_padding:, :, :
284
+ ],
285
+ masks,
286
+ ],
287
+ dim=0,
288
+ )
245
289
  else:
246
- _walk_single(
247
- max_height,
248
- max_width,
249
- self.tile_step_width,
250
- self.tile_step_height,
251
- self.inside_brain_tiles,
252
- input_volume,
253
- self.volume,
254
- self.kernel,
255
- ball_radius,
256
- self.middle_z_idx,
257
- self.overlap_threshold,
258
- self.THRESHOLD_VALUE,
259
- self.SOMA_CENTRE_VALUE,
260
- )
290
+ self.volume = planes.clone()
291
+ if self.inside_brain_tiles is not None:
292
+ self.inside_brain_tiles = masks.clone()
261
293
 
294
+ def get_processed_planes(self) -> np.ndarray:
295
+ """
296
+ After passing enough planes to `append`, and after `walk`, this returns
297
+ a copy of the processed planes as a numpy z-stack.
262
298
 
263
- @njit(cache=True)
264
- def _cube_overlaps(
265
- volume: np.ndarray,
266
- x_start: int,
267
- x_end: int,
268
- y_start: int,
269
- y_end: int,
270
- overlap_threshold: float,
271
- threshold_value: int,
272
- kernel: np.ndarray,
273
- ) -> bool: # Highly optimised because most time critical
274
- """
275
- For each pixel in cube in volume that is greater than THRESHOLD_VALUE, sum
276
- up the corresponding pixels in *kernel*. If the total is less than
277
- overlap_threshold, return False, otherwise return True.
299
+ It only starts returning planes corresponding to plane
300
+ `first_valid_plane` relative to the first planes passed to `append`.
301
+ E.g. if `ball_z_size` is 3 and `first_valid_plane` is 1, and you passed
302
+ 5 planes total to `append`, then this will have returned planes 1 to 3.
278
303
 
279
- Halfway through scanning the z-planes, if the total overlap is
280
- less than 0.4 * overlap_threshold, this will return False early
281
- without scanning the second half of the z-planes.
304
+ Notice the last plane was not included, because we return only "middle"
305
+ planes - planes that can correspond to the center of a ball.
306
+ """
307
+ if not self.ready:
308
+ raise TypeError("Not enough planes were appended")
309
+
310
+ num_processed = self.volume.shape[0] - self.kernel_z_size + 1
311
+ assert num_processed
312
+ middle = self.middle_z_idx
313
+ planes = (
314
+ self.volume[middle : middle + num_processed, :, :]
315
+ .cpu()
316
+ .numpy()
317
+ .copy()
318
+ )
319
+ return planes
282
320
 
283
- Parameters
284
- ----------
285
- volume :
286
- 3D array.
287
- x_start, x_end, y_start, y_end :
288
- The start and end indices in volume that form the cube. End is
289
- exclusive
290
- overlap_threshold :
291
- Threshold above which to return True.
292
- threshold_value :
293
- Value above which a pixel is marked as being part of a cell.
294
- kernel :
295
- 3D array, with the same shape as *cube* in the volume.
296
- """
297
- current_overlap_value = 0.0
298
-
299
- middle = np.floor(volume.shape[0] / 2) + 1
300
- halfway_overlap_thresh = (
301
- overlap_threshold * 0.4
302
- ) # FIXME: do not hard code value
303
-
304
- for z in range(volume.shape[0]):
305
- # TODO: OPTIMISE: step from middle to outer boundaries to check
306
- # more data first
307
- #
308
- # If halfway through the array, and the overlap value isn't more than
309
- # 0.4 * the overlap threshold, return
310
- if z == middle and current_overlap_value < halfway_overlap_thresh:
311
- return False # DEBUG: optimisation attempt
312
-
313
- for y in range(y_start, y_end):
314
- for x in range(x_start, x_end):
315
- # includes self.SOMA_CENTRE_VALUE
316
- if volume[z, x, y] >= threshold_value:
317
- # x/y must be shifted in kernel because we x/y is relative
318
- # to the full volume, so shift it to relative to the cube
319
- current_overlap_value += kernel[
320
- x - x_start, y - y_start, z
321
- ]
322
-
323
- return current_overlap_value > overlap_threshold
324
-
325
-
326
- @njit
327
- def _is_tile_to_check(
328
- x: int,
329
- y: int,
330
- middle_z: int,
331
- tile_step_width: int,
332
- tile_step_height: int,
333
- inside_brain_tiles: np.ndarray,
334
- ) -> bool: # Highly optimised because most time critical
335
- """
336
- Check if the tile containing pixel (x, y) is a tile that needs checking.
337
- """
338
- x_in_mask = x // tile_step_width # TEST: test bounds (-1 range)
339
- y_in_mask = y // tile_step_height # TEST: test bounds (-1 range)
340
- return inside_brain_tiles[middle_z, x_in_mask, y_in_mask]
321
+ def walk(self) -> None:
322
+ """
323
+ Applies the filter to all the planes passed to `append`.
341
324
 
325
+ May only be called if `ready` was True.
342
326
 
343
- def _walk_base(
344
- max_height: int,
345
- max_width: int,
346
- tile_step_width: int,
327
+ You can get the processed planes from `get_processed_planes`.
328
+ """
329
+ if not self.ready:
330
+ raise TypeError("Called walk before enough planes were appended")
331
+
332
+ _walk(
333
+ self.kernel_z_size,
334
+ self.middle_z_idx,
335
+ self.tile_step_height,
336
+ self.tile_step_width,
337
+ self.overlap_threshold,
338
+ self.THRESHOLD_VALUE,
339
+ self.SOMA_CENTRE_VALUE,
340
+ self.kernel,
341
+ self.volume,
342
+ self.inside_brain_tiles,
343
+ )
344
+
345
+
346
+ @torch.jit.script
347
+ def _walk(
348
+ kernel_z_size: int,
349
+ middle: int,
347
350
  tile_step_height: int,
348
- inside_brain_tiles: np.ndarray,
349
- input_volume: np.ndarray,
350
- volume: np.ndarray,
351
- kernel: np.ndarray,
352
- ball_radius: int,
353
- middle_z: int,
351
+ tile_step_width: int,
354
352
  overlap_threshold: float,
355
353
  threshold_value: int,
356
354
  soma_centre_value: int,
357
- ) -> None:
358
- """
359
- Scan through *volume*, and mark pixels where there are enough surrounding
360
- pixels with high enough intensity.
355
+ kernel: torch.Tensor,
356
+ volume: torch.Tensor,
357
+ inside_brain_tiles: Optional[torch.Tensor],
358
+ ):
359
+ num_process = volume.shape[0] - kernel_z_size + 1
360
+ height, width = volume.shape[1:]
361
+ num_z = kernel.shape[2]
362
+
363
+ # threshold volume so it's zero/one. And add two dims at start so
364
+ # it's 11ZYX
365
+ volume_tresh = (
366
+ (volume >= threshold_value)
367
+ .unsqueeze(0)
368
+ .unsqueeze(0)
369
+ .type(kernel.dtype)
370
+ )
361
371
 
362
- The surrounding area is defined by the *kernel*.
372
+ # we do a plane at a time, volume: i:i+num_z, for plane i+middle
373
+ for i in range(num_process):
374
+ # spherical kernel is symmetric so convolution=correlation. Use
375
+ # binary threshold mask over the kernel to sum the value of the
376
+ # kernel at voxels that are bright
377
+ overlap = F.conv3d(
378
+ volume_tresh[:, :, i : i + num_z, :, :],
379
+ kernel,
380
+ stride=1,
381
+ padding="valid",
382
+ )[0, 0, 0, :, :]
383
+ overlaps = overlap > overlap_threshold
384
+
385
+ # only edit the volume that is valid - conv excludes edges so we
386
+ # only edit the plane parts returned by conv3d
387
+ height_valid, width_valid = overlaps.shape
388
+ height_offset = (height - height_valid) // 2
389
+ width_offset = (width - width_valid) // 2
390
+ sub_volume = volume[
391
+ i + middle,
392
+ height_offset : height_offset + height_valid,
393
+ width_offset : width_offset + width_valid,
394
+ ]
395
+
396
+ # do we use tile masks or just overlapping?
397
+ if inside_brain_tiles is not None:
398
+ # unfold tiles to cover the full voxel area each tile covers
399
+ inside = (
400
+ inside_brain_tiles[i + middle, :, :]
401
+ .repeat_interleave(tile_step_height, dim=0)
402
+ .repeat_interleave(tile_step_width, dim=1)
403
+ )
404
+ # again only process pixels in the valid area
405
+ inside = inside[
406
+ height_offset : height_offset + height_valid,
407
+ width_offset : width_offset + width_valid,
408
+ ]
363
409
 
364
- Parameters
365
- ----------
366
- max_height, max_width :
367
- Maximum offsets for the ball filter.
368
- inside_brain_tiles :
369
- 3d array containing information on whether a tile is
370
- inside the brain or not. Tiles outside the brain are skipped.
371
- input_volume :
372
- 3D array containing the plane-filtered data passed to the function
373
- before walking. volume is edited in place, so this is the original
374
- volume to prevent the changes for some cubes affective other cubes
375
- during a single walk call.
376
- volume :
377
- 3D array containing the plane-filtered data - edited in place.
378
- kernel :
379
- 3D array
380
- ball_radius :
381
- Radius of the ball in the xy plane.
382
- soma_centre_value :
383
- Value that is used to mark pixels in *volume*.
384
-
385
- Notes
386
- -----
387
- Warning: modifies volume in place!
388
- """
389
- for y in prange(max_height):
390
- for x in prange(max_width):
391
- ball_centre_x = x + ball_radius
392
- ball_centre_y = y + ball_radius
393
- if _is_tile_to_check(
394
- ball_centre_x,
395
- ball_centre_y,
396
- middle_z,
397
- tile_step_width,
398
- tile_step_height,
399
- inside_brain_tiles,
400
- ):
401
- if _cube_overlaps(
402
- input_volume,
403
- x,
404
- x + kernel.shape[0],
405
- y,
406
- y + kernel.shape[1],
407
- overlap_threshold,
408
- threshold_value,
409
- kernel,
410
- ):
411
- volume[middle_z, ball_centre_x, ball_centre_y] = (
412
- soma_centre_value
413
- )
414
-
415
-
416
- _walk_parallel = njit(parallel=True)(_walk_base)
417
- _walk_single = njit(parallel=False)(_walk_base)
410
+ # must have enough ball overlap to be bright/tile is in brain
411
+ sub_volume[torch.logical_and(overlaps, inside)] = soma_centre_value
412
+
413
+ else:
414
+ # must have enough ball overlap to be bright
415
+ sub_volume[overlaps] = soma_centre_value