cellfinder 1.3.3__py3-none-any.whl → 1.4.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release: this version of cellfinder might be problematic.

@@ -1,70 +1,427 @@
+ """
+ Container for all the settings used during 2d/3d filtering and cell detection.
+ """
+
  import math
- from typing import Tuple
+ from dataclasses import dataclass
+ from functools import cached_property
+ from typing import Callable, Optional, Tuple, Type

  import numpy as np
+ from brainglobe_utils.general.system import get_num_processes

- from cellfinder.core.detect.filters.volume.ball_filter import BallFilter
- from cellfinder.core.detect.filters.volume.structure_detection import (
-     CellDetector,
+ from cellfinder.core.tools.tools import (
+     get_data_converter,
+     get_max_possible_int_value,
  )
- from cellfinder.core.tools.tools import get_max_possible_value
-
-
- def get_ball_filter(
-     *,
-     plane: np.ndarray,
-     soma_diameter: int,
-     ball_xy_size: int,
-     ball_z_size: int,
-     ball_overlap_fraction: float = 0.6,
- ) -> BallFilter:
-     # thrsh_val is used to clip the data in plane to make sure
-     # a number is available to mark cells. soma_centre_val is the
-     # number used to mark cells.
-     max_value = get_max_possible_value(plane)
-     thrsh_val = max_value - 1
-     soma_centre_val = max_value
-
-     tile_width = soma_diameter * 2
-     plane_height, plane_width = plane.shape
-
-     ball_filter = BallFilter(
-         plane_width,
-         plane_height,
-         ball_xy_size,
-         ball_z_size,
-         overlap_fraction=ball_overlap_fraction,
-         tile_step_width=tile_width,
-         tile_step_height=tile_width,
-         threshold_value=thrsh_val,
-         soma_centre_value=soma_centre_val,
-     )
-     return ball_filter
-
-
- def get_cell_detector(
-     *, plane_shape: Tuple[int, int], ball_z_size: int, z_offset: int = 0
- ) -> CellDetector:
-     plane_height, plane_width = plane_shape
-     start_z = z_offset + int(math.floor(ball_z_size / 2))
-     return CellDetector(plane_width, plane_height, start_z=start_z)
-
-
- def setup_tile_filtering(plane: np.ndarray) -> Tuple[int, int]:
-     """
-     Setup values that are used to threshold the plane during 2D filtering.
-
-     Returns
-     -------
-     clipping_value :
-         Upper value used to clip planes before 2D filtering. This is chosen
-         to leave two numbers left that can later be used to mark bright points
-         during the 2D and 3D filtering stages.
-     threshold_value :
-         Value used to mark bright pixels after 2D filtering.
-     """
-     max_value = get_max_possible_value(plane)
-     clipping_value = max_value - 2
-     thrsh_val = max_value - 1
-
-     return clipping_value, thrsh_val
+
+ MAX_TORCH_COMP_THREADS = 12
+ # As seen in the benchmarks in the original PR, running on the CPU with
+ # more than ~12 cores starts to result in slowdowns. So limit to this
+ # many cores when doing computational work (e.g. torch.functional.Conv2D).
+ #
+ # This prevents thread contention.
+
+
+ @dataclass
+ class DetectionSettings:
+     """
+     Configuration class with all the parameters used during 2d and 3d
+     filtering and structure splitting.
+     """
+
+     plane_original_np_dtype: Type[np.number] = np.uint16
+     """
+     The numpy data type of the input data that will be passed to the
+     filtering pipeline.
+
+     At key stages throughout filtering, the data range is kept such that we
+     can convert the data back to this data type without having to scale.
+     I.e. the min/max of the data fits in this data type.
+
+     The exception is the cell detection stage, where the data range can be
+     larger because the values are cell IDs rather than intensity data.
+
+     During structure splitting we do just 3d filtering/cell detection, and
+     this is again the data type used as input to the filtering.
+
+     Defaults to `uint16`.
+     """
+
+     detection_dtype: Type[np.number] = np.uint64
+     """
+     The numpy data type that the cell detection code expects our filtered
+     data to be in.
+
+     After filtering, where the voxels are intensity values, we pass the data
+     to cell detection, where the voxels turn into cell IDs. So the data type
+     needs to be large enough to support the number of cells in the data.
+
+     To get the data from the filtering data type to the detection data type,
+     use `detection_data_converter_func`.
+
+     Defaults to `uint64`.
+     """
+
+     plane_shape: Tuple[int, int] = (1, 1)
+     """
+     The shape of each plane of the input data as (height, width) - i.e.
+     (axis 1, axis 2) in the z-stack, where z is the first axis.
+     """
+
+     start_plane: int = 0
+     """The index of the first plane to process in the input data
+     (inclusive)."""
+
+     end_plane: int = 1
+     """
+     The index of the plane at which to stop processing the input data
+     (exclusive).
+     """
+
+     voxel_sizes: Tuple[float, float, float] = (1.0, 1.0, 1.0)
+     """
+     Tuple of voxel sizes in each dimension (z, y, x). We use this to convert
+     from um to pixel sizes.
+     """
+
+     soma_spread_factor: float = 1.4
+     """Spread factor for soma size - how much it may stretch in the images."""
+
+     soma_diameter_um: float = 16
+     """
+     Diameter of a typical soma, in um. Bright areas larger than this will be
+     split.
+     """
+
+     max_cluster_size_um3: float = 100_000
+     """
+     Maximum size of a cluster (bright area) that will be processed, in um^3.
+     Larger bright areas are skipped as artifacts.
+     """
+
+     ball_xy_size_um: float = 6
+     """
+     Diameter of the 3d spherical kernel filter in the x/y dimensions, in um.
+     See `ball_xy_size` for the size in voxels.
+     """
+
+     ball_z_size_um: float = 15
+     """
+     Diameter of the 3d spherical kernel filter in the z dimension, in um.
+     See `ball_z_size` for the size in voxels.
+
+     `ball_z_size` also determines the minimum number of planes that are
+     stacked before we can filter the central plane of the stack.
+     """
+
+     ball_overlap_fraction: float = 0.6
+     """
+     Fraction of overlap between a bright area and the spherical kernel
+     required for the area to be considered a single ball.
+     """
+
+     log_sigma_size: float = 0.2
+     """Size of the sigma for the 2d Gaussian filter."""
+
+     n_sds_above_mean_thresh: float = 10
+     """
+     Number of standard deviations above the mean intensity to use as the
+     threshold that defines bright areas. Pixels below it are not considered
+     bright.
+     """
+
+     outlier_keep: bool = False
+     """Whether to keep outlier structures during detection."""
+
+     artifact_keep: bool = False
+     """Whether to keep artifact structures during detection."""
+
+     save_planes: bool = False
+     """
+     Whether to save the 2d/3d filtered planes after filtering.
+
+     They are saved as tiffs of data type `plane_original_np_dtype`.
+     """
+
+     plane_directory: Optional[str] = None
+     """Directory path in which to save the planes, if saving."""
+
+     batch_size: int = 1
+     """
+     The number of planes to process in each batch of the 2d/3d filters.
+
+     For CPU, each plane in a batch is 2d filtered (the slowest filters) in
+     its own sub-process, but 3d filtering happens in a single thread. So
+     larger batches will use more processes and can speed up filtering, until
+     IO/3d filters become the bottleneck.
+
+     For CUDA, 2d and 3d filtering happen on the GPU, and the larger the
+     batch size, the better the performance - until it fills up the GPU
+     memory, after which it becomes slower.
+
+     In all cases, a higher batch size means more RAM used.
+     """
+
+     num_prefetch_batches: int = 2
+     """
+     The number of batches to load into memory.
+
+     This many batches are loaded in memory so the next batch is ready to be
+     sent to the filters as soon as the previous batch is done.
+
+     The higher the number, the more RAM used, but it can also speed up
+     processing if IO becomes a limiting factor.
+     """
+
+     torch_device: str = "cpu"
+     """
+     The device on which to run the 2d and/or 3d filtering.
+
+     Either `"cpu"` or PyTorch's GPU device name, such as `"cuda"` or
+     `"cuda:0"` to run on the first GPU.
+     """
+
+     n_free_cpus: int = 2
+     """
+     Number of free CPU cores to keep available and not use during parallel
+     processing. Internally, more cores may actually be used by the system,
+     which we don't control.
+     """
+
+     n_splitting_iter: int = 10
+     """
+     During the structure splitting phase we iteratively shrink the bright
+     areas and re-filter with the 3d filter. This is the number of iterations
+     to do.
+
+     This is a maximum: we stop early if no structures are left after an
+     iteration.
+     """
+
+     def __getstate__(self):
+         d = self.__dict__.copy()
+         # When sending the settings across processes, we need to be able to
+         # pickle them. This cached property cannot be pickled (and doesn't
+         # need to be).
+         if "filter_data_converter_func" in d:
+             del d["filter_data_converter_func"]
+         return d
+
+     @cached_property
+     def filter_data_converter_func(self) -> Callable[[np.ndarray], np.ndarray]:
+         """
+         A callable that takes a numpy array of type
+         `plane_original_np_dtype` and converts it into the `filtering_dtype`
+         type.
+
+         We use this to convert the input data into the data type used for
+         filtering.
+         """
+         return get_data_converter(
+             self.plane_original_np_dtype, self.filtering_dtype
+         )
+
+     @cached_property
+     def filtering_dtype(self) -> Type[np.floating]:
+         """
+         The numpy data type that the 2d/3d filters expect our data to be in.
+         Use `filter_data_converter_func` to convert to this type.
+
+         The data will be used in the form of torch tensors, but it'll be this
+         data type.
+
+         Currently, it's either float32 or float64.
+         """
+         original_dtype = self.plane_original_np_dtype
+         original_max_int = get_max_possible_int_value(original_dtype)
+
+         # does the original data fit in float32
+         if original_max_int <= get_max_possible_int_value(np.float32):
+             return np.float32
+         # what about float64
+         if original_max_int <= get_max_possible_int_value(np.float64):
+             return np.float64
+         raise TypeError("Input array data type is too big for a float64")
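
For intuition, the dtype selection above can be sketched as follows. Here `largest_exact_int` is a hypothetical stand-in for `get_max_possible_int_value`, assuming it returns the dtype's max value for integer dtypes and the largest exactly-representable integer for float dtypes:

    import numpy as np

    def largest_exact_int(dtype) -> int:
        # Hypothetical stand-in for get_max_possible_int_value: max value
        # for integer dtypes, largest exactly-representable integer for
        # float dtypes (2**24 for float32, 2**53 for float64).
        if np.issubdtype(dtype, np.integer):
            return int(np.iinfo(dtype).max)
        return 2 ** (np.finfo(dtype).nmant + 1)

    # uint16 input fits in float32's exact-integer range...
    assert largest_exact_int(np.uint16) <= largest_exact_int(np.float32)
    # ...uint32 does not, so it would be filtered as float64. uint64
    # exceeds even float64's exact range and would raise the TypeError above.
    assert largest_exact_int(np.uint32) > largest_exact_int(np.float32)
    assert largest_exact_int(np.uint32) <= largest_exact_int(np.float64)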
+
+     @cached_property
+     def clipping_value(self) -> int:
+         """
+         The maximum value used to clip the input to, as well as the value to
+         which the filtered data is scaled during filtering.
+
+         This ensures the filtered data fits in the `plane_original_np_dtype`.
+         """
+         return get_max_possible_int_value(self.plane_original_np_dtype) - 2
+
+     @cached_property
+     def threshold_value(self) -> int:
+         """
+         The value used to mark bright areas as above the brightness
+         threshold, during 2d filtering.
+         """
+         return get_max_possible_int_value(self.plane_original_np_dtype) - 1
+
+     @cached_property
+     def soma_centre_value(self) -> int:
+         """
+         The value used to mark bright areas as the location of a soma center,
+         during 3d filtering.
+         """
+         return get_max_possible_int_value(self.plane_original_np_dtype)
+
+     @cached_property
+     def detection_soma_centre_value(self) -> int:
+         """
+         The value used to mark bright areas as the location of a soma center,
+         during detection. See `detection_data_converter_func`.
+         """
+         return get_max_possible_int_value(self.detection_dtype)
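
For the default `uint16` input these three markers sit at the very top of the dtype's range, which is why the input is first clipped to `clipping_value`. A minimal sketch, assuming `get_max_possible_int_value` returns `np.iinfo(dtype).max` for integer dtypes and that `DetectionSettings` lives at the module path below:

    import numpy as np
    from cellfinder.core.detect.filters.setup_filters import DetectionSettings

    settings = DetectionSettings(plane_original_np_dtype=np.uint16)
    assert settings.clipping_value == 65533     # max - 2
    assert settings.threshold_value == 65534    # max - 1
    assert settings.soma_centre_value == 65535  # max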
+
+     @cached_property
+     def detection_data_converter_func(
+         self,
+     ) -> Callable[[np.ndarray], np.ndarray]:
+         """
+         A callable that takes a numpy array of type
+         `filtering_dtype` and converts it into the `detection_dtype`
+         type.
+
+         It takes the filtered data where somas are marked with the
+         `soma_centre_value` and returns a volume of the same size where the
+         somas are marked with `detection_soma_centre_value`. Other voxels are
+         zeroed.
+
+         We use this to convert the output of the 3d filter into the data
+         passed to cell detection.
+         """
+
+         def convert_for_cell_detection(data: np.ndarray) -> np.ndarray:
+             detection_data = np.zeros_like(data, dtype=self.detection_dtype)
+             detection_data[data == self.soma_centre_value] = (
+                 self.detection_soma_centre_value
+             )
+             return detection_data
+
+         return convert_for_cell_detection
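
A toy run of the returned converter (continuing the sketch above): voxels equal to `soma_centre_value` are remapped to `detection_soma_centre_value`, and everything else is zeroed:

    filtered = np.array([0.0, 1000.0, 65535.0], dtype=np.float32)
    convert = settings.detection_data_converter_func
    detected = convert(filtered)
    assert detected.dtype == np.uint64
    assert detected[2] == settings.detection_soma_centre_value
    assert detected[0] == detected[1] == 0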
+
+     @property
+     def tile_height(self) -> int:
+         """
+         The height of each tile of the tiled input image, used during
+         filtering to mark individual tiles as inside/outside the brain.
+         """
+         return self.soma_diameter * 2
+
+     @property
+     def tile_width(self) -> int:
+         """
+         The width of each tile of the tiled input image, used during
+         filtering to mark individual tiles as inside/outside the brain.
+         """
+         return self.soma_diameter * 2
+
+     @property
+     def plane_height(self) -> int:
+         """The height of each input plane of the z-stack."""
+         return self.plane_shape[0]
+
+     @property
+     def plane_width(self) -> int:
+         """The width of each input plane of the z-stack."""
+         return self.plane_shape[1]
+
+     @property
+     def n_planes(self) -> int:
+         """The number of planes in the z-stack."""
+         return self.end_plane - self.start_plane
+
+     @property
+     def n_processes(self) -> int:
+         """The maximum number of processes we can use during detection."""
+         n = get_num_processes(min_free_cpu_cores=self.n_free_cpus)
+         return max(n - 1, 1)
+
+     @property
+     def n_torch_comp_threads(self) -> int:
+         """
+         The maximum number of threads we should use during filtering with
+         pytorch.
+
+         This is less than `n_processes` because we account for thread
+         contention. Specifically, it's limited by `MAX_TORCH_COMP_THREADS`.
+         """
+         # Reserve batch_size cores for batch multiprocess parallelization on
+         # CPU, 1 per plane. For GPU it doesn't matter either way because it
+         # doesn't use threads. Also reserve cores for the data feeding thread
+         # and cell detection. Don't let it go below 4.
+         n = max(4, self.n_processes - self.batch_size - 2)
+         n = min(n, self.n_processes)
+         return min(n, MAX_TORCH_COMP_THREADS)
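
As a worked example of the thread budget (hypothetical machine where `n_processes` is 17 and `batch_size` is 4):

    n = max(4, 17 - 4 - 2)  # 11: cores left after batch, feeder, detector
    n = min(n, 17)          # still capped by n_processes
    n = min(n, 12)          # and by MAX_TORCH_COMP_THREADS -> 11 threads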
+
+     @property
+     def in_plane_pixel_size(self) -> float:
+         """Returns the average in-plane (xy) um/pixel."""
+         voxel_sizes = self.voxel_sizes
+         return (voxel_sizes[2] + voxel_sizes[1]) / 2
+
+     @cached_property
+     def soma_diameter(self) -> int:
+         """The `soma_diameter_um`, but in voxels."""
+         return int(round(self.soma_diameter_um / self.in_plane_pixel_size))
+
+     @cached_property
+     def max_cluster_size(self) -> int:
+         """The `max_cluster_size_um3`, but in voxels."""
+         voxel_sizes = self.voxel_sizes
+         voxel_volume = (
+             float(voxel_sizes[2])
+             * float(voxel_sizes[1])
+             * float(voxel_sizes[0])
+         )
+         return int(round(self.max_cluster_size_um3 / voxel_volume))
+
+     @cached_property
+     def ball_xy_size(self) -> int:
+         """The `ball_xy_size_um`, but in voxels."""
+         return int(round(self.ball_xy_size_um / self.in_plane_pixel_size))
+
+     @property
+     def z_pixel_size(self) -> float:
+         """Returns the um/pixel in the z direction."""
+         return self.voxel_sizes[0]
+
+     @cached_property
+     def ball_z_size(self) -> int:
+         """The `ball_z_size_um`, but in voxels."""
+         ball_z_size = int(round(self.ball_z_size_um / self.z_pixel_size))
+
+         if not ball_z_size:
+             raise ValueError(
+                 "Ball z size has been calculated to be 0 voxels."
+                 " This may be due to large axial spacing of your data or the "
+                 "ball_z_size_um parameter being too small. "
+                 "Please check input parameters are correct. "
+                 "Note that cellfinder requires high resolution data in all "
+                 "dimensions, so that cells can be detected in multiple "
+                 "image planes."
+             )
+         return ball_z_size
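
For example, with anisotropic voxels of `voxel_sizes=(5.0, 2.0, 2.0)` (z, y, x, in um), the conversions above work out as follows (same assumed import as earlier):

    settings = DetectionSettings(
        voxel_sizes=(5.0, 2.0, 2.0),  # 5 um axial spacing, 2 um in-plane
        soma_diameter_um=16,
        ball_xy_size_um=6,
        ball_z_size_um=15,
    )
    assert settings.in_plane_pixel_size == 2.0  # (2 + 2) / 2
    assert settings.soma_diameter == 8          # round(16 / 2)
    assert settings.ball_xy_size == 3           # round(6 / 2)
    assert settings.ball_z_size == 3            # round(15 / 5), non-zero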
+
+     @property
+     def max_cell_volume(self) -> float:
+         """
+         The maximum cell volume to consider as a single cell, in voxels.
+
+         If we find a bright area larger than that, we will split it.
+         """
+         radius = self.soma_spread_factor * self.soma_diameter / 2
+         return (4 / 3) * math.pi * radius**3
+
+     @property
+     def plane_prefix(self) -> str:
+         """
+         The prefix of the filename to use to save the 2d/3d filtered planes.
+
+         To save plane `k`, do `plane_prefix.format(n=k)`. You can then add
+         an extension etc.
+         """
+         n = max(4, int(math.ceil(math.log10(self.n_planes))))
+         name = f"plane_{{n:0{n}d}}"
+         return name
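
Usage sketch for the prefix: for a 1500-plane stack, `ceil(log10(1500))` is 4, so plane numbers are zero-padded to four digits (same assumed import as earlier):

    settings = DetectionSettings(start_plane=0, end_plane=1500)
    assert settings.n_planes == 1500
    assert settings.plane_prefix == "plane_{n:04d}"
    assert settings.plane_prefix.format(n=42) == "plane_0042"
    # e.g. settings.plane_prefix.format(n=42) + ".tif" when saving planes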