cellfinder 1.3.2__py3-none-any.whl → 1.4.0a0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -1,8 +1,8 @@
1
1
  import math
2
- from typing import Generator, Tuple
2
+ from typing import Tuple
3
3
 
4
- import numpy as np
5
- from numba import njit
4
+ import torch
5
+ import torch.nn.functional as F
6
6
 
7
7
 
8
8
  class TileWalker:
@@ -15,74 +15,140 @@ class TileWalker:
15
15
  The mean and standard deviation of this tile are calculated, and
16
16
  the threshold set at 1 + mean + (2 * stddev).
17
17
 
18
- Attributes
18
+ Parameters
19
19
  ----------
20
- bright_tiles_mask :
21
- A boolean array whose entries correspond to whether each tile is
22
- bright (1) or dark (0). The values are set in
23
- self.mark_bright_tiles().
20
+ plane_shape : tuple(int, int)
21
+ Height/width of the planes.
22
+ soma_diameter : int
23
+ Diameter of the soma in voxels.
24
24
  """
25
25
 
26
- def __init__(self, img: np.ndarray, soma_diameter: int) -> None:
27
- self.img = img
28
- self.img_width, self.img_height = img.shape
29
- self.tile_width = soma_diameter * 2
30
- self.tile_height = soma_diameter * 2
26
+ def __init__(
27
+ self, plane_shape: Tuple[int, int], soma_diameter: int
28
+ ) -> None:
31
29
 
32
- n_tiles_width = math.ceil(self.img_width / self.tile_width)
33
- n_tiles_height = math.ceil(self.img_height / self.tile_height)
34
- self.bright_tiles_mask = np.zeros(
35
- (n_tiles_width, n_tiles_height), dtype=bool
36
- )
30
+ self.img_height, self.img_width = plane_shape
31
+ self.tile_height = soma_diameter * 2
32
+ self.tile_width = soma_diameter * 2
37
33
 
38
- corner_tile = img[0 : self.tile_width, 0 : self.tile_height]
39
- corner_intensity = np.mean(corner_tile)
40
- corner_sd = np.std(corner_tile)
41
- # add 1 to ensure not 0, as disables
42
- self.out_of_brain_threshold = (corner_intensity + (2 * corner_sd)) + 1
34
+ self.n_tiles_height = math.ceil(self.img_height / self.tile_height)
35
+ self.n_tiles_width = math.ceil(self.img_width / self.tile_width)
43
36
 
44
- def _get_tiles(self) -> Generator[Tuple[int, int, np.ndarray], None, None]:
37
+ def get_bright_tiles(self, planes: torch.Tensor) -> torch.Tensor:
45
38
  """
46
- Generator that yields tiles of the 2D image.
39
+ Takes a 3d z-stack. For each z it computes the mean/std of the corner
40
+ tile and uses that to get an in/out of brain threshold for each z.
47
41
 
48
- Notes
49
- -----
50
- The final tile in each dimension can have a smaller size than the
51
- rest of the tiles if the tile shape does not exactly divide the
52
- image shape.
53
- """
54
- for y in range(
55
- 0, self.img_height - self.tile_height, self.tile_height
56
- ):
57
- for x in range(
58
- 0, self.img_width - self.tile_width, self.tile_width
59
- ):
60
- tile = self.img[
61
- x : x + self.tile_width, y : y + self.tile_height
62
- ]
63
- yield x, y, tile
64
-
65
- def mark_bright_tiles(self) -> None:
66
- """
67
- Loop through tiles, and if the average value of a tile is
68
- greater than the intensity threshold mark the tile as bright
69
- in self.bright_tiles_mask.
42
+ Parameters
43
+ ----------
44
+ planes : torch.Tensor
45
+ 3d z-stack.
46
+
47
+ Returns
48
+ -------
49
+ bright_tiles_mask : torch.Tensor
50
+ 3d z-stack whose planar shape is the number of tiles in a plane.
51
+ The returned data will be on the same torch device as the input
52
+ planes.
70
53
  """
71
- threshold = self.out_of_brain_threshold
72
- if threshold == 0:
73
- return
54
+ return _get_bright_tiles(
55
+ planes,
56
+ self.n_tiles_height,
57
+ self.n_tiles_width,
58
+ self.tile_height,
59
+ self.tile_width,
60
+ )
74
61
 
75
- for x, y, tile in self._get_tiles():
76
- if not is_low_average(tile, threshold):
77
- mask_x = x // self.tile_width
78
- mask_y = y // self.tile_height
79
- self.bright_tiles_mask[mask_x, mask_y] = True
62
+ def get_tiled_buffer(self, depth: int, device: str):
63
+ return torch.zeros(
64
+ (depth, self.n_tiles_height, self.n_tiles_width),
65
+ dtype=torch.bool,
66
+ device=device,
67
+ )
80
68
 
81
69
 
82
- @njit
83
- def is_low_average(tile: np.ndarray, threshold: float) -> bool:
70
+ @torch.jit.script
71
+ def _get_out_of_brain_threshold(
72
+ planes: torch.Tensor, tile_height: int, tile_width: int
73
+ ) -> torch.Tensor:
84
74
  """
85
- Return `True` if the average value of *tile* is below *threshold*.
75
+ Takes a 3d z-stack. For each z it computes the mean/std of the corner tile
76
+ and uses that to get an in/out of brain threshold for each z-plane.
77
+
78
+ Parameters
79
+ ----------
80
+ planes :
81
+ 3d z-stack.
82
+ tile_height :
83
+ Height of each tile.
84
+ tile_width :
85
+ Width of each tile.
86
+
87
+ Returns
88
+ -------
89
+ out_of_brain_thresholds :
90
+ 1d z-stack.
91
+ """
92
+ # get corner tile
93
+ corner_tiles = planes[:, 0:tile_height, 0:tile_width]
94
+ # convert from ZYX -> ZK, where K is the elements in the corner tile
95
+ corner_tiles = corner_tiles.reshape((planes.shape[0], -1))
96
+
97
+ # need to operate in float64, in case the values are large
98
+ corner64 = corner_tiles.type(torch.float64)
99
+ corner_intensity = torch.mean(corner64, dim=1).type(planes.dtype)
100
+ # for parity with past when we used np.std, which defaults to ddof=0
101
+ corner_sd = torch.std(corner64, dim=1, correction=0).type(planes.dtype)
102
+ # add 1 so the threshold is never 0 (0 previously disabled the check)
103
+ out_of_brain_thresholds = corner_intensity + 2 * corner_sd + 1
104
+
105
+ return out_of_brain_thresholds
106
+
107
+
108
+ @torch.jit.script
109
+ def _get_bright_tiles(
110
+ planes: torch.Tensor,
111
+ n_tiles_height: int,
112
+ n_tiles_width: int,
113
+ tile_height: int,
114
+ tile_width: int,
115
+ ) -> torch.Tensor:
86
116
  """
87
- avg = np.mean(tile)
88
- return avg < threshold
117
+ For each plane, compute the average value of every tile. If the average
118
+ value of a tile is at least the intensity threshold of that plane,
119
+ mark the tile as bright.
120
+ """
121
+ bright_tiles_mask = torch.zeros(
122
+ (planes.shape[0], n_tiles_height, n_tiles_width),
123
+ dtype=torch.bool,
124
+ device=planes.device,
125
+ )
126
+ # if we don't have enough size for a single tile, it's all outside
127
+ if planes.shape[1] < tile_height or planes.shape[2] < tile_width:
128
+ return bright_tiles_mask
129
+
130
+ # for each plane, the threshold
131
+ out_of_brain_thresholds = _get_out_of_brain_threshold(
132
+ planes, tile_height, tile_width
133
+ )
134
+ # thresholds Z -> ZYX shape
135
+ thresholds = out_of_brain_thresholds.view(-1, 1, 1)
136
+
137
+ # ZYX -> ZCYX required for function (C=1)
138
+ planes = planes.unsqueeze(1)
139
+ # get the average of each tile
140
+ tile_avg = F.avg_pool2d(
141
+ planes,
142
+ (tile_height, tile_width),
143
+ ceil_mode=False, # default is False, but to make sure
144
+ )
145
+ # go back from ZCYX -> ZYX
146
+ tile_avg = tile_avg[:, 0, :, :]
147
+
148
+ bright = tile_avg >= thresholds
149
+ # tile_avg and bright may be smaller than bright_tiles_mask because
150
+ # avg_pool2d first subtracts the kernel size before computing the number
151
+ # of tiles. So constrain the mask view to that size
+ bright_tiles_mask[:, : bright.shape[1], : bright.shape[2]][bright] = True
153
+
154
+ return bright_tiles_mask