cellfinder 1.3.3.tar.gz → 1.4.0a0.tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Files changed (81)
  1. {cellfinder-1.3.3 → cellfinder-1.4.0a0}/.github/workflows/test_and_deploy.yml +26 -4
  2. {cellfinder-1.3.3 → cellfinder-1.4.0a0}/PKG-INFO +3 -2
  3. {cellfinder-1.3.3 → cellfinder-1.4.0a0}/cellfinder/core/classify/classify.py +1 -1
  4. cellfinder-1.4.0a0/cellfinder/core/detect/detect.py +236 -0
  5. cellfinder-1.4.0a0/cellfinder/core/detect/filters/plane/classical_filter.py +347 -0
  6. cellfinder-1.4.0a0/cellfinder/core/detect/filters/plane/plane_filter.py +169 -0
  7. cellfinder-1.4.0a0/cellfinder/core/detect/filters/plane/tile_walker.py +154 -0
  8. cellfinder-1.4.0a0/cellfinder/core/detect/filters/setup_filters.py +427 -0
  9. cellfinder-1.4.0a0/cellfinder/core/detect/filters/volume/ball_filter.py +415 -0
  10. {cellfinder-1.3.3 → cellfinder-1.4.0a0}/cellfinder/core/detect/filters/volume/structure_detection.py +73 -35
  11. cellfinder-1.4.0a0/cellfinder/core/detect/filters/volume/structure_splitting.py +306 -0
  12. cellfinder-1.4.0a0/cellfinder/core/detect/filters/volume/volume_filter.py +523 -0
  13. {cellfinder-1.3.3 → cellfinder-1.4.0a0}/cellfinder/core/main.py +6 -2
  14. cellfinder-1.4.0a0/cellfinder/core/tools/IO.py +45 -0
  15. cellfinder-1.4.0a0/cellfinder/core/tools/threading.py +380 -0
  16. cellfinder-1.4.0a0/cellfinder/core/tools/tools.py +295 -0
  17. {cellfinder-1.3.3 → cellfinder-1.4.0a0}/cellfinder.egg-info/PKG-INFO +3 -2
  18. {cellfinder-1.3.3 → cellfinder-1.4.0a0}/cellfinder.egg-info/SOURCES.txt +2 -0
  19. {cellfinder-1.3.3 → cellfinder-1.4.0a0}/cellfinder.egg-info/requires.txt +2 -1
  20. {cellfinder-1.3.3 → cellfinder-1.4.0a0}/pyproject.toml +4 -1
  21. cellfinder-1.3.3/cellfinder/core/detect/detect.py +0 -301
  22. cellfinder-1.3.3/cellfinder/core/detect/filters/plane/classical_filter.py +0 -45
  23. cellfinder-1.3.3/cellfinder/core/detect/filters/plane/plane_filter.py +0 -87
  24. cellfinder-1.3.3/cellfinder/core/detect/filters/plane/tile_walker.py +0 -88
  25. cellfinder-1.3.3/cellfinder/core/detect/filters/setup_filters.py +0 -70
  26. cellfinder-1.3.3/cellfinder/core/detect/filters/volume/ball_filter.py +0 -417
  27. cellfinder-1.3.3/cellfinder/core/detect/filters/volume/structure_splitting.py +0 -242
  28. cellfinder-1.3.3/cellfinder/core/detect/filters/volume/volume_filter.py +0 -202
  29. cellfinder-1.3.3/cellfinder/core/tools/tools.py +0 -173
  30. {cellfinder-1.3.3 → cellfinder-1.4.0a0}/.github/workflows/test_include_guard.yaml +0 -0
  31. {cellfinder-1.3.3 → cellfinder-1.4.0a0}/.gitignore +0 -0
  32. {cellfinder-1.3.3 → cellfinder-1.4.0a0}/.napari/config.yml +0 -0
  33. {cellfinder-1.3.3 → cellfinder-1.4.0a0}/CITATION.cff +0 -0
  34. {cellfinder-1.3.3 → cellfinder-1.4.0a0}/LICENSE +0 -0
  35. {cellfinder-1.3.3 → cellfinder-1.4.0a0}/MANIFEST.in +0 -0
  36. {cellfinder-1.3.3 → cellfinder-1.4.0a0}/README.md +0 -0
  37. {cellfinder-1.3.3 → cellfinder-1.4.0a0}/cellfinder/__init__.py +0 -0
  38. {cellfinder-1.3.3 → cellfinder-1.4.0a0}/cellfinder/cli_migration_warning.py +0 -0
  39. {cellfinder-1.3.3 → cellfinder-1.4.0a0}/cellfinder/core/__init__.py +0 -0
  40. {cellfinder-1.3.3 → cellfinder-1.4.0a0}/cellfinder/core/classify/__init__.py +0 -0
  41. {cellfinder-1.3.3 → cellfinder-1.4.0a0}/cellfinder/core/classify/augment.py +0 -0
  42. {cellfinder-1.3.3 → cellfinder-1.4.0a0}/cellfinder/core/classify/cube_generator.py +0 -0
  43. {cellfinder-1.3.3 → cellfinder-1.4.0a0}/cellfinder/core/classify/resnet.py +0 -0
  44. {cellfinder-1.3.3 → cellfinder-1.4.0a0}/cellfinder/core/classify/tools.py +0 -0
  45. {cellfinder-1.3.3 → cellfinder-1.4.0a0}/cellfinder/core/config/__init__.py +0 -0
  46. {cellfinder-1.3.3 → cellfinder-1.4.0a0}/cellfinder/core/config/cellfinder.conf +0 -0
  47. {cellfinder-1.3.3 → cellfinder-1.4.0a0}/cellfinder/core/detect/__init__.py +0 -0
  48. {cellfinder-1.3.3 → cellfinder-1.4.0a0}/cellfinder/core/detect/filters/__init__.py +0 -0
  49. {cellfinder-1.3.3 → cellfinder-1.4.0a0}/cellfinder/core/detect/filters/plane/__init__.py +0 -0
  50. {cellfinder-1.3.3 → cellfinder-1.4.0a0}/cellfinder/core/detect/filters/volume/__init__.py +0 -0
  51. {cellfinder-1.3.3 → cellfinder-1.4.0a0}/cellfinder/core/download/__init__.py +0 -0
  52. {cellfinder-1.3.3 → cellfinder-1.4.0a0}/cellfinder/core/download/cli.py +0 -0
  53. {cellfinder-1.3.3 → cellfinder-1.4.0a0}/cellfinder/core/download/download.py +0 -0
  54. {cellfinder-1.3.3 → cellfinder-1.4.0a0}/cellfinder/core/tools/__init__.py +0 -0
  55. {cellfinder-1.3.3 → cellfinder-1.4.0a0}/cellfinder/core/tools/array_operations.py +0 -0
  56. {cellfinder-1.3.3 → cellfinder-1.4.0a0}/cellfinder/core/tools/geometry.py +0 -0
  57. {cellfinder-1.3.3 → cellfinder-1.4.0a0}/cellfinder/core/tools/image_processing.py +0 -0
  58. {cellfinder-1.3.3 → cellfinder-1.4.0a0}/cellfinder/core/tools/prep.py +0 -0
  59. {cellfinder-1.3.3 → cellfinder-1.4.0a0}/cellfinder/core/tools/source_files.py +0 -0
  60. {cellfinder-1.3.3 → cellfinder-1.4.0a0}/cellfinder/core/tools/system.py +0 -0
  61. {cellfinder-1.3.3 → cellfinder-1.4.0a0}/cellfinder/core/tools/tiff.py +0 -0
  62. {cellfinder-1.3.3 → cellfinder-1.4.0a0}/cellfinder/core/train/__init__.py +0 -0
  63. {cellfinder-1.3.3 → cellfinder-1.4.0a0}/cellfinder/core/train/train_yml.py +0 -0
  64. {cellfinder-1.3.3 → cellfinder-1.4.0a0}/cellfinder/core/types.py +0 -0
  65. {cellfinder-1.3.3 → cellfinder-1.4.0a0}/cellfinder/napari/__init__.py +0 -0
  66. {cellfinder-1.3.3 → cellfinder-1.4.0a0}/cellfinder/napari/curation.py +0 -0
  67. {cellfinder-1.3.3 → cellfinder-1.4.0a0}/cellfinder/napari/detect/__init__.py +0 -0
  68. {cellfinder-1.3.3 → cellfinder-1.4.0a0}/cellfinder/napari/detect/detect.py +0 -0
  69. {cellfinder-1.3.3 → cellfinder-1.4.0a0}/cellfinder/napari/detect/detect_containers.py +0 -0
  70. {cellfinder-1.3.3 → cellfinder-1.4.0a0}/cellfinder/napari/detect/thread_worker.py +0 -0
  71. {cellfinder-1.3.3 → cellfinder-1.4.0a0}/cellfinder/napari/input_container.py +0 -0
  72. {cellfinder-1.3.3 → cellfinder-1.4.0a0}/cellfinder/napari/napari.yaml +0 -0
  73. {cellfinder-1.3.3 → cellfinder-1.4.0a0}/cellfinder/napari/sample_data.py +0 -0
  74. {cellfinder-1.3.3 → cellfinder-1.4.0a0}/cellfinder/napari/train/__init__.py +0 -0
  75. {cellfinder-1.3.3 → cellfinder-1.4.0a0}/cellfinder/napari/train/train.py +0 -0
  76. {cellfinder-1.3.3 → cellfinder-1.4.0a0}/cellfinder/napari/train/train_containers.py +0 -0
  77. {cellfinder-1.3.3 → cellfinder-1.4.0a0}/cellfinder/napari/utils.py +0 -0
  78. {cellfinder-1.3.3 → cellfinder-1.4.0a0}/cellfinder.egg-info/dependency_links.txt +0 -0
  79. {cellfinder-1.3.3 → cellfinder-1.4.0a0}/cellfinder.egg-info/entry_points.txt +0 -0
  80. {cellfinder-1.3.3 → cellfinder-1.4.0a0}/cellfinder.egg-info/top_level.txt +0 -0
  81. {cellfinder-1.3.3 → cellfinder-1.4.0a0}/setup.cfg +0 -0
{cellfinder-1.3.3 → cellfinder-1.4.0a0}/.github/workflows/test_and_deploy.yml

@@ -38,11 +38,14 @@ jobs:
   test:
     needs: [linting, manifest]
     name: Run package tests
-    timeout-minutes: 60
+    timeout-minutes: 120
     runs-on: ${{ matrix.os }}
     env:
       KERAS_BACKEND: torch
       CELLFINDER_TEST_DEVICE: cpu
+      # pooch cache dir
+      BRAINGLOBE_TEST_DATA_DIR: "~/.pooch_cache"
+
     strategy:
       matrix:
         # Run all supported Python versions on linux
@@ -56,6 +59,14 @@ jobs:
             python-version: "3.12"

     steps:
+      - uses: actions/checkout@v4
+      - name: Cache pooch data
+        uses: actions/cache@v4
+        with:
+          path: "~/.pooch_cache"
+          # hash on conftest in case url changes
+          key: ${{ runner.os }}-${{ matrix.python-version }}-${{ hashFiles('**/pooch_registry.txt') }}
+      # Cache the tensorflow model so we don't have to remake it every time
       - name: Cache brainglobe directory
         uses: actions/cache@v3
         with:
@@ -78,12 +89,16 @@ jobs:
   test_numba_disabled:
     needs: [linting, manifest]
     name: Run tests with numba disabled
-    timeout-minutes: 60
+    timeout-minutes: 120
     runs-on: ubuntu-latest
     env:
-      NUMBA_DISABLE_JIT: "1"
+      NUMBA_DISABLE_JIT: "1"
+      PYTORCH_JIT: "0"
+      # pooch cache dir
+      BRAINGLOBE_TEST_DATA_DIR: "~/.pooch_cache"

     steps:
+      - uses: actions/checkout@v4
       - name: Cache brainglobe directory
         uses: actions/cache@v3
         with:
@@ -91,6 +106,13 @@ jobs:
             ~/.brainglobe
             !~/.brainglobe/atlas.tar.gz
           key: brainglobe
+
+      - name: Cache pooch data
+        uses: actions/cache@v4
+        with:
+          path: "~/.pooch_cache"
+          key: ${{ runner.os }}-3.10-${{ hashFiles('**/pooch_registry.txt') }}
+
       # Setup pyqt libraries
       - name: Setup qtpy libraries
         uses: tlambert03/setup-qt-libs@v1
@@ -108,7 +130,7 @@ jobs:
   test_brainmapper_cli:
     needs: [linting, manifest]
     name: Run brainmapper tests to check for breakages
-    timeout-minutes: 60
+    timeout-minutes: 120
     runs-on: ubuntu-latest
     env:
       KERAS_BACKEND: torch
{cellfinder-1.3.3 → cellfinder-1.4.0a0}/PKG-INFO

@@ -1,6 +1,6 @@
 Metadata-Version: 2.1
 Name: cellfinder
-Version: 1.3.3
+Version: 1.4.0a0
 Summary: Automated 3D cell detection in large microscopy images
 Author-email: "Adam Tyson, Christian Niedworok, Charly Rousseau" <code@adamltyson.com>
 License: BSD-3-Clause
@@ -33,7 +33,7 @@ Requires-Dist: numpy
 Requires-Dist: scikit-image
 Requires-Dist: scikit-learn
 Requires-Dist: keras==3.5.0
-Requires-Dist: torch>=2.1.0
+Requires-Dist: torch!=2.4,>=2.1.0
 Requires-Dist: tifffile
 Requires-Dist: tqdm
 Provides-Extra: dev
@@ -46,6 +46,7 @@ Requires-Dist: pytest-qt; extra == "dev"
 Requires-Dist: pytest-timeout; extra == "dev"
 Requires-Dist: pytest; extra == "dev"
 Requires-Dist: tox; extra == "dev"
+Requires-Dist: pooch>=1; extra == "dev"
 Provides-Extra: napari
 Requires-Dist: brainglobe-napari-io; extra == "napari"
 Requires-Dist: magicgui; extra == "napari"
{cellfinder-1.3.3 → cellfinder-1.4.0a0}/cellfinder/core/classify/classify.py

@@ -30,7 +30,7 @@ def main(
     max_workers: int = 3,
     *,
     callback: Optional[Callable[[int], None]] = None,
-) -> List:
+) -> List[Cell]:
     """
     Parameters
     ----------
cellfinder-1.4.0a0/cellfinder/core/detect/detect.py (new file)

@@ -0,0 +1,236 @@
+"""
+Detection is run in three steps:
+
+1. 2D filtering
+2. 3D filtering
+3. Structure detection
+
+In steps 1. and 2. filters are applied, and any bright points detected
+post-filter are marked. To avoid using a separate mask array to mark the
+bright points, the input data is clipped to [0, (max_val - 2)]
+(max_val is the maximum value that the image data type can store), and:
+- (max_val - 1) is used to mark bright points during 2D filtering
+- (max_val) is used to mark bright points during 3D filtering
+"""
+
+import dataclasses
+from datetime import datetime
+from typing import Callable, List, Optional, Tuple
+
+import numpy as np
+import torch
+from brainglobe_utils.cells.cells import Cell
+
+from cellfinder.core import logger, types
+from cellfinder.core.detect.filters.plane import TileProcessor
+from cellfinder.core.detect.filters.setup_filters import DetectionSettings
+from cellfinder.core.detect.filters.volume.volume_filter import VolumeFilter
+from cellfinder.core.tools.tools import inference_wrapper
+
+
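The marking scheme described in the module docstring above is easiest to see with a small, purely illustrative sketch (not code from the package): for a uint16 image, max_val is 65535, so the data is clipped to [0, 65533] and the two top values are reserved as markers.

    import numpy as np

    plane = np.array([[100, 60000], [65535, 3]], dtype=np.uint16)

    max_val = np.iinfo(plane.dtype).max       # 65535 for uint16
    clipped = np.clip(plane, 0, max_val - 2)  # data now never exceeds 65533

    # hypothetical positions flagged by the two filtering stages
    clipped[0, 1] = max_val - 1  # 65534 marks a bright point from 2D filtering
    clipped[1, 0] = max_val      # 65535 marks a bright point from 3D filtering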
+@inference_wrapper
+def main(
+    signal_array: types.array,
+    start_plane: int = 0,
+    end_plane: int = -1,
+    voxel_sizes: Tuple[float, float, float] = (5, 2, 2),
+    soma_diameter: float = 16,
+    max_cluster_size: float = 100_000,
+    ball_xy_size: float = 6,
+    ball_z_size: float = 15,
+    ball_overlap_fraction: float = 0.6,
+    soma_spread_factor: float = 1.4,
+    n_free_cpus: int = 2,
+    log_sigma_size: float = 0.2,
+    n_sds_above_mean_thresh: float = 10,
+    outlier_keep: bool = False,
+    artifact_keep: bool = False,
+    save_planes: bool = False,
+    plane_directory: Optional[str] = None,
+    batch_size: Optional[int] = None,
+    torch_device: str = "cpu",
+    use_scipy: bool = True,
+    split_ball_xy_size: int = 3,
+    split_ball_z_size: int = 3,
+    split_ball_overlap_fraction: float = 0.8,
+    split_soma_diameter: int = 7,
+    *,
+    callback: Optional[Callable[[int], None]] = None,
+) -> List[Cell]:
+    """
+    Perform cell candidate detection on a 3D signal array.
+
+    Parameters
+    ----------
+    signal_array : numpy.ndarray
+        3D array representing the signal data.
+
+    start_plane : int
+        Index of the starting plane for detection.
+
+    end_plane : int
+        Index of the ending plane for detection.
+
+    voxel_sizes : Tuple[float, float, float]
+        Tuple of voxel sizes in each dimension (z, y, x).
+
+    soma_diameter : float
+        Diameter of the soma in physical units.
+
+    max_cluster_size : float
+        Maximum size of a cluster in physical units.
+
+    ball_xy_size : float
+        Size of the XY ball used for filtering in physical units.
+
+    ball_z_size : float
+        Size of the Z ball used for filtering in physical units.
+
+    ball_overlap_fraction : float
+        Fraction of overlap allowed between balls.
+
+    soma_spread_factor : float
+        Spread factor for soma size.
+
+    n_free_cpus : int
+        Number of free CPU cores available for parallel processing.
+
+    log_sigma_size : float
+        Size of the sigma for the log filter.
+
+    n_sds_above_mean_thresh : float
+        Number of standard deviations above the mean threshold.
+
+    outlier_keep : bool, optional
+        Whether to keep outliers during detection. Defaults to False.
+
+    artifact_keep : bool, optional
+        Whether to keep artifacts during detection. Defaults to False.
+
+    save_planes : bool, optional
+        Whether to save the planes during detection. Defaults to False.
+
+    plane_directory : str, optional
+        Directory path to save the planes. Defaults to None.
+
+    batch_size : int, optional
+        The number of planes to process in each batch. Defaults to 1.
+        For CPU, there's no benefit for a larger batch size. Only a memory
+        usage increase. For CUDA, the larger the batch size the better the
+        performance. Until it fills up the GPU memory - after which it
+        becomes slower.
+
+    torch_device : str, optional
+        The device on which to run the computation. By default, it's "cpu".
+        To run on a gpu, specify the PyTorch device name, such as "cuda" to
+        run on the first GPU.
+
+    callback : Callable[int], optional
+        A callback function that is called every time a plane has finished
+        being processed. Called with the plane number that has finished.
+
+    Returns
+    -------
+    List[Cell]
+        List of detected cells.
+    """
+    start_time = datetime.now()
+    if batch_size is None:
+        if torch_device == "cpu":
+            batch_size = 4
+        else:
+            batch_size = 1
+
+    if not np.issubdtype(signal_array.dtype, np.number):
+        raise TypeError(
+            "signal_array must be a numpy datatype, but has datatype "
+            f"{signal_array.dtype}"
+        )
+
+    if signal_array.ndim != 3:
+        raise ValueError("Input data must be 3D")
+
+    if end_plane < 0:
+        end_plane = len(signal_array)
+    end_plane = min(len(signal_array), end_plane)
+
+    torch_device = torch_device.lower()
+    batch_size = max(batch_size, 1)
+    # brainmapper can pass them in as str
+    voxel_sizes = list(map(float, voxel_sizes))
+
+    settings = DetectionSettings(
+        plane_shape=signal_array.shape[1:],
+        plane_original_np_dtype=signal_array.dtype,
+        voxel_sizes=voxel_sizes,
+        soma_spread_factor=soma_spread_factor,
+        soma_diameter_um=soma_diameter,
+        max_cluster_size_um3=max_cluster_size,
+        ball_xy_size_um=ball_xy_size,
+        ball_z_size_um=ball_z_size,
+        start_plane=start_plane,
+        end_plane=end_plane,
+        n_free_cpus=n_free_cpus,
+        ball_overlap_fraction=ball_overlap_fraction,
+        log_sigma_size=log_sigma_size,
+        n_sds_above_mean_thresh=n_sds_above_mean_thresh,
+        outlier_keep=outlier_keep,
+        artifact_keep=artifact_keep,
+        save_planes=save_planes,
+        plane_directory=plane_directory,
+        batch_size=batch_size,
+        torch_device=torch_device,
+    )
+
+    # replicate the settings specific to splitting, before we access anything
+    # of the original settings, causing cached properties
+    kwargs = dataclasses.asdict(settings)
+    kwargs["ball_z_size_um"] = split_ball_z_size * settings.z_pixel_size
+    kwargs["ball_xy_size_um"] = (
+        split_ball_xy_size * settings.in_plane_pixel_size
+    )
+    kwargs["ball_overlap_fraction"] = split_ball_overlap_fraction
+    kwargs["soma_diameter_um"] = (
+        split_soma_diameter * settings.in_plane_pixel_size
+    )
+    # always run on cpu because copying to gpu overhead is likely slower than
+    # any benefit for detection on smallish volumes
+    kwargs["torch_device"] = "cpu"
+    # for splitting, we only do 3d filtering. Its input is a zero volume
+    # with cell voxels marked with threshold_value. So just use float32
+    # for input because the filters will also use float(32). So there will
+    # not be need to convert the input a different dtype before passing to
+    # the filters.
+    kwargs["plane_original_np_dtype"] = np.float32
+    splitting_settings = DetectionSettings(**kwargs)
+
+    # Create 3D analysis filter
+    mp_3d_filter = VolumeFilter(settings=settings)
+
+    # Create 2D analysis filter
+    mp_tile_processor = TileProcessor(
+        plane_shape=settings.plane_shape,
+        clipping_value=settings.clipping_value,
+        threshold_value=settings.threshold_value,
+        n_sds_above_mean_thresh=n_sds_above_mean_thresh,
+        log_sigma_size=log_sigma_size,
+        soma_diameter=settings.soma_diameter,
+        torch_device=torch_device,
+        dtype=settings.filtering_dtype.__name__,
+        use_scipy=use_scipy,
+    )
+
+    orig_n_threads = torch.get_num_threads()
+    torch.set_num_threads(settings.n_torch_comp_threads)
+
+    # process the data
+    mp_3d_filter.process(mp_tile_processor, signal_array, callback=callback)
+    cells = mp_3d_filter.get_results(splitting_settings)
+
+    torch.set_num_threads(orig_n_threads)
+
+    time_elapsed = datetime.now() - start_time
+    s = f"Detection complete. Found {len(cells)} cells in {time_elapsed}"
+    logger.debug(s)
+    print(s)
+    return cells
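For reference, a minimal way to exercise this new entry point (a sketch only: the import path follows the file layout above, the keyword names come from the signature in this diff, and it assumes cellfinder 1.4.0a0 and its dependencies are installed):

    import numpy as np
    from cellfinder.core.detect.detect import main as detect_main

    # small random z-stack (planes, y, x); real input would be a microscopy volume
    signal = np.random.randint(0, 1000, size=(20, 128, 128), dtype=np.uint16)

    cells = detect_main(
        signal,
        voxel_sizes=(5, 2, 2),  # (z, y, x)
        soma_diameter=16,
        torch_device="cpu",
    )
    print(f"{len(cells)} cell candidates")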
cellfinder-1.4.0a0/cellfinder/core/detect/filters/plane/classical_filter.py (new file)

@@ -0,0 +1,347 @@
+import numpy as np
+import torch
+import torch.nn.functional as F
+from scipy.ndimage import gaussian_filter, laplace
+from scipy.signal import medfilt2d
+
+
+@torch.jit.script
+def normalize(
+    filtered_planes: torch.Tensor,
+    flip: bool,
+    max_value: float = 1.0,
+) -> None:
+    """
+    Normalizes the 3d tensor so each z-plane is independently scaled to be
+    in the [0, max_value] range. If `flip` is `True`, the sign of the tensor
+    values are flipped before any processing.
+
+    It is done to filtered_planes inplace.
+    """
+    num_z = filtered_planes.shape[0]
+    filtered_planes_1d = filtered_planes.view(num_z, -1)
+
+    if flip:
+        filtered_planes_1d.mul_(-1)
+
+    planes_min = torch.min(filtered_planes_1d, dim=1, keepdim=True)[0]
+    filtered_planes_1d.sub_(planes_min)
+    # take max after subtraction
+    planes_max = torch.max(filtered_planes_1d, dim=1, keepdim=True)[0]
+    # if min = max = zero, divide by 1 - it'll stay zero
+    planes_max[planes_max == 0] = 1
+    filtered_planes_1d.div_(planes_max)
+
+    if max_value != 1.0:
+        # To leave room to label in the 3d detection.
+        filtered_planes_1d.mul_(max_value)
+
+
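The `normalize` helper above is an in-place, per-plane min-max rescale. The same operation written against NumPy, for reference only (a sketch, not part of the package):

    import numpy as np

    def normalize_planes(stack, flip=False, max_value=1.0):
        # reference version of the per-plane rescale done by `normalize` above
        out = stack.astype(np.float32)
        if flip:
            out = -out
        flat = out.reshape(out.shape[0], -1)
        flat -= flat.min(axis=1, keepdims=True)
        peak = flat.max(axis=1, keepdims=True)
        peak[peak == 0] = 1  # all-zero planes stay zero
        flat /= peak
        return (flat * max_value).reshape(out.shape)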
+@torch.jit.script
+def filter_for_peaks(
+    planes: torch.Tensor,
+    med_kernel: torch.Tensor,
+    gauss_kernel: torch.Tensor,
+    gauss_kernel_size: int,
+    lap_kernel: torch.Tensor,
+    device: str,
+    clipping_value: float,
+) -> torch.Tensor:
+    """
+    Takes the 3d z-stack and returns a new z-stack where the peaks are
+    highlighted.
+
+    It applies a median filter -> gaussian filter -> laplacian filter.
+    """
+    filtered_planes = planes.unsqueeze(1)  # ZYX -> ZCYX input, C=channels
+
+    # ------------------ median filter ------------------
+    # extracts patches to compute median over for each pixel
+    # We go from ZCYX -> ZCYX, C=1 to C=9 with C containing the elements around
+    # each Z,X,Y voxel over which we compute the median
+    # Zero padding is ok here
+    filtered_planes = F.conv2d(filtered_planes, med_kernel, padding="same")
+    # we're going back to ZCYX=Z1YX by taking median of patches in C dim
+    filtered_planes = filtered_planes.median(dim=1, keepdim=True)[0]
+
+    # ------------------ gaussian filter ------------------
+    # normalize the input data to 0-1 range. Otherwise, if the values are
+    # large, we'd need a float64 so conv result is accurate
+    normalize(filtered_planes, flip=False)
+
+    # we need to do reflection padding around the tensor for parity with scipy
+    # gaussian filtering. Scipy does reflection in a manner typically called
+    # symmetric: (dcba|abcd|dcba). Torch does it like this: (dcb|abcd|cba). So
+    # we manually do symmetric padding below
+    pad = gauss_kernel_size // 2
+    padding_mode = "reflect"
+    # if data is too small for reflect, just use constant border value
+    if pad >= filtered_planes.shape[-1] or pad >= filtered_planes.shape[-2]:
+        padding_mode = "replicate"
+    filtered_planes = F.pad(filtered_planes, (pad,) * 4, padding_mode, 0.0)
+    # We reflected torch style, so copy/shift everything by one to be symmetric
+    filtered_planes[:, :, :pad, :] = filtered_planes[
+        :, :, 1 : pad + 1, :
+    ].clone()
+    filtered_planes[:, :, -pad:, :] = filtered_planes[
+        :, :, -pad - 1 : -1, :
+    ].clone()
+    filtered_planes[:, :, :, :pad] = filtered_planes[
+        :, :, :, 1 : pad + 1
+    ].clone()
+    filtered_planes[:, :, :, -pad:] = filtered_planes[
+        :, :, :, -pad - 1 : -1
+    ].clone()
+
+    # We apply the 1D gaussian filter twice, once for Y and once for X. The
+    # filter shape passed in is 11K1 or 111K, depending on device. Where
+    # K=filter size
+    # see https://discuss.pytorch.org/t/performance-issue-for-conv2d-with-1d-
+    # filter-along-a-dim/201734/2 for the reason for the moveaxis depending
+    # on the device
+    if device == "cpu":
+        # kernel shape is 11K1. First do Y (second to last axis)
+        filtered_planes = F.conv2d(
+            filtered_planes, gauss_kernel, padding="valid"
+        )
+        # To do X, exchange X,Y axis, filter, change back. On CPU, Y (second
+        # to last) axis is faster.
+        filtered_planes = F.conv2d(
+            filtered_planes.moveaxis(-1, -2), gauss_kernel, padding="valid"
+        ).moveaxis(-1, -2)
+    else:
+        # kernel shape is 111K
+        # First do Y (second to last axis). Exchange X,Y axis, filter, change
+        # back. On CUDA, X (last) axis is faster.
+        filtered_planes = F.conv2d(
+            filtered_planes.moveaxis(-1, -2), gauss_kernel, padding="valid"
+        ).moveaxis(-1, -2)
+        # now do X, last axis
+        filtered_planes = F.conv2d(
+            filtered_planes, gauss_kernel, padding="valid"
+        )
+
+    # ------------------ laplacian filter ------------------
+    # it's a 2d filter. Need to pad using symmetric for scipy parity. But,
+    # torch doesn't have it, and we used a kernel of size 3, so for padding of
+    # 1, replicate == symmetric. That's enough for parity with past scipy. If
+    # we change kernel size in the future, we may have to do as above
+    padding = lap_kernel.shape[-1] // 2
+    filtered_planes = F.pad(filtered_planes, (padding,) * 4, "replicate")
+    filtered_planes = F.conv2d(filtered_planes, lap_kernel, padding="valid")
+
+    # we don't need the channel axis
+    filtered_planes = filtered_planes[:, 0, :, :]
+
+    # scale back to full scale, filtered values are negative so flip
+    normalize(filtered_planes, flip=True, max_value=clipping_value)
+    return filtered_planes
+
+
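The two 1D convolutions above rely on the separability of the Gaussian: filtering rows and then columns with a 1D kernel matches a single 2D Gaussian filter. A quick SciPy check of that property (illustrative only, separate from the torch code path):

    import numpy as np
    from scipy.ndimage import gaussian_filter, gaussian_filter1d

    img = np.random.rand(64, 64).astype(np.float32)
    sigma = 2.0

    # 1D Gaussian along axis 0, then axis 1, equals the full 2D Gaussian filter
    separable = gaussian_filter1d(gaussian_filter1d(img, sigma, axis=0), sigma, axis=1)
    assert np.allclose(separable, gaussian_filter(img, sigma), atol=1e-5)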
+class PeakEnhancer:
+    """
+    A class that filters each plane in a z-stack such that peaks are
+    visualized.
+
+    It uses a series of 2D filters of median -> gaussian ->
+    laplacian. Then normalizes each plane to be between [0, clipping_value].
+
+    Parameters
+    ----------
+    torch_device: str
+        The device on which the data and processing occurs on. Can be e.g.
+        "cpu", "cuda" etc. Any data passed to the filter must be on this
+        device. Returned data will also be on this device.
+    dtype : torch.dtype
+        The data-type of the input planes and the type to use internally.
+        E.g. `torch.float32`.
+    clipping_value : int
+        The value such that after normalizing, the max value will be this
+        clipping_value.
+    laplace_gaussian_sigma : float
+        Size of the sigma for the gaussian filter.
+    use_scipy : bool
+        If running on the CPU whether to use the scipy filters or the same
+        pytorch filters used on CUDA. Scipy filters can be faster.
+    """
+
+    # binary kernel that generates square patches for each pixel so we can find
+    # the median around the pixel
+    med_kernel: torch.Tensor
+
+    # gaussian 1D kernel with kernel/weight shape 11K1 or 111K, depending
+    # on device. Where K=filter size
+    gauss_kernel: torch.Tensor
+
+    # 2D laplacian kernel with kernel/weight shape KxK. Where
+    # K=filter size
+    lap_kernel: torch.Tensor
+
+    # the value such that after normalizing, the max value will be this
+    # clipping_value
+    clipping_value: float
+
+    # sigma value for gaussian filter
+    laplace_gaussian_sigma: float
+
+    # the torch device to run on. E.g. cpu/cuda.
+    torch_device: str
+
+    # when running on CPU whether to use pytorch or scipy for filters
+    use_scipy: bool
+
+    median_filter_size: int = 3
+    """
+    The median filter size in x/y direction.
+
+    **Must** be odd.
+    """
+
+    def __init__(
+        self,
+        torch_device: str,
+        dtype: torch.dtype,
+        clipping_value: float,
+        laplace_gaussian_sigma: float,
+        use_scipy: bool,
+    ):
+        super().__init__()
+        self.torch_device = torch_device.lower()
+        self.clipping_value = clipping_value
+        self.laplace_gaussian_sigma = laplace_gaussian_sigma
+        self.use_scipy = use_scipy
+
+        # all these kernels are odd in size
+        self.med_kernel = self._get_median_kernel(torch_device, dtype)
+        self.gauss_kernel = self._get_gaussian_kernel(
+            torch_device, dtype, laplace_gaussian_sigma
+        )
+        self.lap_kernel = self._get_laplacian_kernel(torch_device, dtype)
+
+    @property
+    def gaussian_filter_size(self) -> int:
+        """
+        The gaussian filter 1d size.
+
+        It is odd.
+        """
+        return 2 * int(round(4 * self.laplace_gaussian_sigma)) + 1
+
+    def _get_median_kernel(
+        self, torch_device: str, dtype: torch.dtype
+    ) -> torch.Tensor:
+        """
+        Gets a median patch generator kernel, already on the correct
+        device.
+
+        Based on how kornia does it for median filtering.
+        """
+        # must be odd kernel
+        kernel_n = self.median_filter_size
+        if not (kernel_n % 2):
+            raise ValueError("The median filter size must be odd")
+
+        # extract patches to compute median over for each pixel. When passing
+        # input we go from ZCYX -> ZCYX, C=1 to C=9 and containing the elements
+        # around each Z,X,Y over which we can then compute the median
+        window_range = kernel_n * kernel_n  # e.g. 3x3
+        kernel = torch.zeros(
+            (window_range, window_range), device=torch_device, dtype=dtype
+        )
+        idx = torch.arange(window_range, device=torch_device)
+        # diagonal of e.g. 9x9 is 1
+        kernel[idx, idx] = 1.0
+        # out channels, in channels, n*y, n*x. The kernel collects all the 3x3
+        # elements around a pixel, using a binary mask for each element, as a
+        # separate channel. So we go from 1 to 9 channels in the output
+        kernel = kernel.view(window_range, 1, kernel_n, kernel_n)
+
+        return kernel
+
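`_get_median_kernel` above turns a median filter into a convolution that extracts the 3x3 neighbourhood of each pixel as nine channels, followed by a channel-wise median. A minimal sketch of the same trick, checked against scipy.signal.medfilt2d on the interior of a float32 image (illustrative only):

    import numpy as np
    import torch
    import torch.nn.functional as F
    from scipy.signal import medfilt2d

    img = np.random.rand(32, 32).astype(np.float32)

    # nine one-hot 3x3 kernels: each output channel picks one neighbour
    eye = torch.eye(9).view(9, 1, 3, 3)
    x = torch.from_numpy(img).view(1, 1, 32, 32)
    patches = F.conv2d(x, eye, padding="same")    # shape (1, 9, 32, 32)
    median = patches.median(dim=1)[0][0].numpy()  # channel-wise median

    # interior pixels match scipy's medfilt2d (both pad borders with zeros)
    assert np.allclose(median[1:-1, 1:-1], medfilt2d(img)[1:-1, 1:-1])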
+    def _get_gaussian_kernel(
+        self,
+        torch_device: str,
+        dtype: torch.dtype,
+        laplace_gaussian_sigma: float,
+    ) -> torch.Tensor:
+        """Gets the 1D gaussian kernel used to filter the data."""
+        # we do 2 1D filters, once on each y, x dim.
+        # shape of kernel will be 11K1 with dims Z, C, Y, X. C=1, Z is expanded
+        # to number of z during filtering.
+        kernel_size = self.gaussian_filter_size
+
+        # to get the values of a 1D gaussian kernel, we pass a single impulse
+        # data through the filter, which recovers the filter values. We do this
+        # because scipy doesn't make their kernel available in public API and
+        # we want parity with scipy filtering
+        impulse = np.zeros(kernel_size)
+        # the impulse needs to be to the left of center
+        impulse[kernel_size // 2] = 1
+        kernel = gaussian_filter(
+            impulse, laplace_gaussian_sigma, mode="constant", cval=0
+        )
+        # kernel should be fully symmetric
+        assert kernel[0] == kernel[-1]
+        gauss_kernel = torch.from_numpy(kernel).type(dtype).to(torch_device)
+
+        # default shape is (y, x) with y axis filtered only - we transpose
+        # input to filter on x
+        gauss_kernel = gauss_kernel.view(1, 1, -1, 1)
+
+        # see https://discuss.pytorch.org/t/performance-issue-for-conv2d-
+        # with-1d-filter-along-a-dim/201734. Conv2d is faster on a specific dim
+        # for 1D filters depending on CPU/CUDA. See also filter_for_peaks
+        # on CPU, we only do conv2d on the (1st) dim
+        if torch_device != "cpu":
+            # on CUDA, we only filter on the x dim, flipping input to filter y
+            gauss_kernel = gauss_kernel.view(1, 1, 1, -1)
+
+        return gauss_kernel
+
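The impulse trick in `_get_gaussian_kernel` above is easy to verify: filtering a unit impulse with scipy returns the filter taps themselves, which come out normalised and symmetric. A small check (illustrative, with an arbitrary sigma):

    import numpy as np
    from scipy.ndimage import gaussian_filter

    sigma = 0.8
    size = 2 * int(round(4 * sigma)) + 1  # same odd-size rule as gaussian_filter_size

    impulse = np.zeros(size)
    impulse[size // 2] = 1
    taps = gaussian_filter(impulse, sigma, mode="constant", cval=0)

    assert np.isclose(taps.sum(), 1.0)    # the recovered kernel is normalised
    assert np.allclose(taps, taps[::-1])  # and symmetric about its centre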
+    def _get_laplacian_kernel(
+        self, torch_device: str, dtype: torch.dtype
+    ) -> torch.Tensor:
+        """Gets a 2d laplacian kernel, based on scipy's laplace."""
+        # for parity with scipy, scipy computes the laplacian with default
+        # parameters and kernel size 3 using filter coefficients [1, -2, 1].
+        # Each filtered pixel is the sum of the filter around the pixel
+        # vertically and horizontally. We can do it in 2d at once with
+        # coefficients below (faster than 2x1D for such small filter)
+        return torch.as_tensor(
+            [[0, 1, 0], [1, -4, 1], [0, 1, 0]],
+            dtype=dtype,
+            device=torch_device,
+        ).view(1, 1, 3, 3)
+
+    def enhance_peaks(self, planes: torch.Tensor) -> torch.Tensor:
+        """
+        Applies the filtering and normalization to the 3d z-stack (not inplace)
+        and returns the filtered z-stack.
+        """
+        if self.torch_device == "cpu" and self.use_scipy:
+            filtered_planes = planes.clone()
+            for i in range(planes.shape[0]):
+                img = planes[i, :, :].numpy()
+                img = medfilt2d(img)
+                img = gaussian_filter(img, self.laplace_gaussian_sigma)
+                img = laplace(img)
+                filtered_planes[i, :, :] = torch.from_numpy(img)
+
+            # laplace makes values negative so flip
+            normalize(
+                filtered_planes,
+                flip=True,
+                max_value=self.clipping_value,
+            )
+            return filtered_planes
+
+        filtered_planes = filter_for_peaks(
+            planes,
+            self.med_kernel,
+            self.gauss_kernel,
+            self.gaussian_filter_size,
+            self.lap_kernel,
+            self.torch_device,
+            self.clipping_value,
+        )
+        return filtered_planes
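A minimal way to exercise the new class on a random z-stack (a sketch; the constructor arguments come from the signature above, while the values are arbitrary and only for illustration):

    import torch
    from cellfinder.core.detect.filters.plane.classical_filter import PeakEnhancer

    enhancer = PeakEnhancer(
        torch_device="cpu",
        dtype=torch.float32,
        clipping_value=65533.0,      # e.g. the uint16 ceiling described in detect.py
        laplace_gaussian_sigma=3.2,  # arbitrary sigma for this example
        use_scipy=False,             # exercise the torch code path
    )

    planes = torch.rand(4, 64, 64, dtype=torch.float32)  # (z, y, x) stack
    enhanced = enhancer.enhance_peaks(planes)
    print(enhanced.shape, float(enhanced.max()))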