cellfinder 1.3.3__tar.gz → 1.4.0a0__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of cellfinder might be problematic. Click here for more details.
- {cellfinder-1.3.3 → cellfinder-1.4.0a0}/.github/workflows/test_and_deploy.yml +26 -4
- {cellfinder-1.3.3 → cellfinder-1.4.0a0}/PKG-INFO +3 -2
- {cellfinder-1.3.3 → cellfinder-1.4.0a0}/cellfinder/core/classify/classify.py +1 -1
- cellfinder-1.4.0a0/cellfinder/core/detect/detect.py +236 -0
- cellfinder-1.4.0a0/cellfinder/core/detect/filters/plane/classical_filter.py +347 -0
- cellfinder-1.4.0a0/cellfinder/core/detect/filters/plane/plane_filter.py +169 -0
- cellfinder-1.4.0a0/cellfinder/core/detect/filters/plane/tile_walker.py +154 -0
- cellfinder-1.4.0a0/cellfinder/core/detect/filters/setup_filters.py +427 -0
- cellfinder-1.4.0a0/cellfinder/core/detect/filters/volume/ball_filter.py +415 -0
- {cellfinder-1.3.3 → cellfinder-1.4.0a0}/cellfinder/core/detect/filters/volume/structure_detection.py +73 -35
- cellfinder-1.4.0a0/cellfinder/core/detect/filters/volume/structure_splitting.py +306 -0
- cellfinder-1.4.0a0/cellfinder/core/detect/filters/volume/volume_filter.py +523 -0
- {cellfinder-1.3.3 → cellfinder-1.4.0a0}/cellfinder/core/main.py +6 -2
- cellfinder-1.4.0a0/cellfinder/core/tools/IO.py +45 -0
- cellfinder-1.4.0a0/cellfinder/core/tools/threading.py +380 -0
- cellfinder-1.4.0a0/cellfinder/core/tools/tools.py +295 -0
- {cellfinder-1.3.3 → cellfinder-1.4.0a0}/cellfinder.egg-info/PKG-INFO +3 -2
- {cellfinder-1.3.3 → cellfinder-1.4.0a0}/cellfinder.egg-info/SOURCES.txt +2 -0
- {cellfinder-1.3.3 → cellfinder-1.4.0a0}/cellfinder.egg-info/requires.txt +2 -1
- {cellfinder-1.3.3 → cellfinder-1.4.0a0}/pyproject.toml +4 -1
- cellfinder-1.3.3/cellfinder/core/detect/detect.py +0 -301
- cellfinder-1.3.3/cellfinder/core/detect/filters/plane/classical_filter.py +0 -45
- cellfinder-1.3.3/cellfinder/core/detect/filters/plane/plane_filter.py +0 -87
- cellfinder-1.3.3/cellfinder/core/detect/filters/plane/tile_walker.py +0 -88
- cellfinder-1.3.3/cellfinder/core/detect/filters/setup_filters.py +0 -70
- cellfinder-1.3.3/cellfinder/core/detect/filters/volume/ball_filter.py +0 -417
- cellfinder-1.3.3/cellfinder/core/detect/filters/volume/structure_splitting.py +0 -242
- cellfinder-1.3.3/cellfinder/core/detect/filters/volume/volume_filter.py +0 -202
- cellfinder-1.3.3/cellfinder/core/tools/tools.py +0 -173
- {cellfinder-1.3.3 → cellfinder-1.4.0a0}/.github/workflows/test_include_guard.yaml +0 -0
- {cellfinder-1.3.3 → cellfinder-1.4.0a0}/.gitignore +0 -0
- {cellfinder-1.3.3 → cellfinder-1.4.0a0}/.napari/config.yml +0 -0
- {cellfinder-1.3.3 → cellfinder-1.4.0a0}/CITATION.cff +0 -0
- {cellfinder-1.3.3 → cellfinder-1.4.0a0}/LICENSE +0 -0
- {cellfinder-1.3.3 → cellfinder-1.4.0a0}/MANIFEST.in +0 -0
- {cellfinder-1.3.3 → cellfinder-1.4.0a0}/README.md +0 -0
- {cellfinder-1.3.3 → cellfinder-1.4.0a0}/cellfinder/__init__.py +0 -0
- {cellfinder-1.3.3 → cellfinder-1.4.0a0}/cellfinder/cli_migration_warning.py +0 -0
- {cellfinder-1.3.3 → cellfinder-1.4.0a0}/cellfinder/core/__init__.py +0 -0
- {cellfinder-1.3.3 → cellfinder-1.4.0a0}/cellfinder/core/classify/__init__.py +0 -0
- {cellfinder-1.3.3 → cellfinder-1.4.0a0}/cellfinder/core/classify/augment.py +0 -0
- {cellfinder-1.3.3 → cellfinder-1.4.0a0}/cellfinder/core/classify/cube_generator.py +0 -0
- {cellfinder-1.3.3 → cellfinder-1.4.0a0}/cellfinder/core/classify/resnet.py +0 -0
- {cellfinder-1.3.3 → cellfinder-1.4.0a0}/cellfinder/core/classify/tools.py +0 -0
- {cellfinder-1.3.3 → cellfinder-1.4.0a0}/cellfinder/core/config/__init__.py +0 -0
- {cellfinder-1.3.3 → cellfinder-1.4.0a0}/cellfinder/core/config/cellfinder.conf +0 -0
- {cellfinder-1.3.3 → cellfinder-1.4.0a0}/cellfinder/core/detect/__init__.py +0 -0
- {cellfinder-1.3.3 → cellfinder-1.4.0a0}/cellfinder/core/detect/filters/__init__.py +0 -0
- {cellfinder-1.3.3 → cellfinder-1.4.0a0}/cellfinder/core/detect/filters/plane/__init__.py +0 -0
- {cellfinder-1.3.3 → cellfinder-1.4.0a0}/cellfinder/core/detect/filters/volume/__init__.py +0 -0
- {cellfinder-1.3.3 → cellfinder-1.4.0a0}/cellfinder/core/download/__init__.py +0 -0
- {cellfinder-1.3.3 → cellfinder-1.4.0a0}/cellfinder/core/download/cli.py +0 -0
- {cellfinder-1.3.3 → cellfinder-1.4.0a0}/cellfinder/core/download/download.py +0 -0
- {cellfinder-1.3.3 → cellfinder-1.4.0a0}/cellfinder/core/tools/__init__.py +0 -0
- {cellfinder-1.3.3 → cellfinder-1.4.0a0}/cellfinder/core/tools/array_operations.py +0 -0
- {cellfinder-1.3.3 → cellfinder-1.4.0a0}/cellfinder/core/tools/geometry.py +0 -0
- {cellfinder-1.3.3 → cellfinder-1.4.0a0}/cellfinder/core/tools/image_processing.py +0 -0
- {cellfinder-1.3.3 → cellfinder-1.4.0a0}/cellfinder/core/tools/prep.py +0 -0
- {cellfinder-1.3.3 → cellfinder-1.4.0a0}/cellfinder/core/tools/source_files.py +0 -0
- {cellfinder-1.3.3 → cellfinder-1.4.0a0}/cellfinder/core/tools/system.py +0 -0
- {cellfinder-1.3.3 → cellfinder-1.4.0a0}/cellfinder/core/tools/tiff.py +0 -0
- {cellfinder-1.3.3 → cellfinder-1.4.0a0}/cellfinder/core/train/__init__.py +0 -0
- {cellfinder-1.3.3 → cellfinder-1.4.0a0}/cellfinder/core/train/train_yml.py +0 -0
- {cellfinder-1.3.3 → cellfinder-1.4.0a0}/cellfinder/core/types.py +0 -0
- {cellfinder-1.3.3 → cellfinder-1.4.0a0}/cellfinder/napari/__init__.py +0 -0
- {cellfinder-1.3.3 → cellfinder-1.4.0a0}/cellfinder/napari/curation.py +0 -0
- {cellfinder-1.3.3 → cellfinder-1.4.0a0}/cellfinder/napari/detect/__init__.py +0 -0
- {cellfinder-1.3.3 → cellfinder-1.4.0a0}/cellfinder/napari/detect/detect.py +0 -0
- {cellfinder-1.3.3 → cellfinder-1.4.0a0}/cellfinder/napari/detect/detect_containers.py +0 -0
- {cellfinder-1.3.3 → cellfinder-1.4.0a0}/cellfinder/napari/detect/thread_worker.py +0 -0
- {cellfinder-1.3.3 → cellfinder-1.4.0a0}/cellfinder/napari/input_container.py +0 -0
- {cellfinder-1.3.3 → cellfinder-1.4.0a0}/cellfinder/napari/napari.yaml +0 -0
- {cellfinder-1.3.3 → cellfinder-1.4.0a0}/cellfinder/napari/sample_data.py +0 -0
- {cellfinder-1.3.3 → cellfinder-1.4.0a0}/cellfinder/napari/train/__init__.py +0 -0
- {cellfinder-1.3.3 → cellfinder-1.4.0a0}/cellfinder/napari/train/train.py +0 -0
- {cellfinder-1.3.3 → cellfinder-1.4.0a0}/cellfinder/napari/train/train_containers.py +0 -0
- {cellfinder-1.3.3 → cellfinder-1.4.0a0}/cellfinder/napari/utils.py +0 -0
- {cellfinder-1.3.3 → cellfinder-1.4.0a0}/cellfinder.egg-info/dependency_links.txt +0 -0
- {cellfinder-1.3.3 → cellfinder-1.4.0a0}/cellfinder.egg-info/entry_points.txt +0 -0
- {cellfinder-1.3.3 → cellfinder-1.4.0a0}/cellfinder.egg-info/top_level.txt +0 -0
- {cellfinder-1.3.3 → cellfinder-1.4.0a0}/setup.cfg +0 -0
|
@@ -38,11 +38,14 @@ jobs:
|
|
|
38
38
|
test:
|
|
39
39
|
needs: [linting, manifest]
|
|
40
40
|
name: Run package tests
|
|
41
|
-
timeout-minutes:
|
|
41
|
+
timeout-minutes: 120
|
|
42
42
|
runs-on: ${{ matrix.os }}
|
|
43
43
|
env:
|
|
44
44
|
KERAS_BACKEND: torch
|
|
45
45
|
CELLFINDER_TEST_DEVICE: cpu
|
|
46
|
+
# pooch cache dir
|
|
47
|
+
BRAINGLOBE_TEST_DATA_DIR: "~/.pooch_cache"
|
|
48
|
+
|
|
46
49
|
strategy:
|
|
47
50
|
matrix:
|
|
48
51
|
# Run all supported Python versions on linux
|
|
@@ -56,6 +59,14 @@ jobs:
|
|
|
56
59
|
python-version: "3.12"
|
|
57
60
|
|
|
58
61
|
steps:
|
|
62
|
+
- uses: actions/checkout@v4
|
|
63
|
+
- name: Cache pooch data
|
|
64
|
+
uses: actions/cache@v4
|
|
65
|
+
with:
|
|
66
|
+
path: "~/.pooch_cache"
|
|
67
|
+
# hash on pooch_registry.txt in case url changes
|
|
68
|
+
key: ${{ runner.os }}-${{ matrix.python-version }}-${{ hashFiles('**/pooch_registry.txt') }}
|
|
69
|
+
# Cache the tensorflow model so we don't have to remake it every time
|
|
59
70
|
- name: Cache brainglobe directory
|
|
60
71
|
uses: actions/cache@v3
|
|
61
72
|
with:
|
|
@@ -78,12 +89,16 @@ jobs:
|
|
|
78
89
|
test_numba_disabled:
|
|
79
90
|
needs: [linting, manifest]
|
|
80
91
|
name: Run tests with numba disabled
|
|
81
|
-
timeout-minutes:
|
|
92
|
+
timeout-minutes: 120
|
|
82
93
|
runs-on: ubuntu-latest
|
|
83
94
|
env:
|
|
84
|
-
|
|
95
|
+
NUMBA_DISABLE_JIT: "1"
|
|
96
|
+
PYTORCH_JIT: "0"
|
|
97
|
+
# pooch cache dir
|
|
98
|
+
BRAINGLOBE_TEST_DATA_DIR: "~/.pooch_cache"
|
|
85
99
|
|
|
86
100
|
steps:
|
|
101
|
+
- uses: actions/checkout@v4
|
|
87
102
|
- name: Cache brainglobe directory
|
|
88
103
|
uses: actions/cache@v3
|
|
89
104
|
with:
|
|
@@ -91,6 +106,13 @@ jobs:
|
|
|
91
106
|
~/.brainglobe
|
|
92
107
|
!~/.brainglobe/atlas.tar.gz
|
|
93
108
|
key: brainglobe
|
|
109
|
+
|
|
110
|
+
- name: Cache pooch data
|
|
111
|
+
uses: actions/cache@v4
|
|
112
|
+
with:
|
|
113
|
+
path: "~/.pooch_cache"
|
|
114
|
+
key: ${{ runner.os }}-3.10-${{ hashFiles('**/pooch_registry.txt') }}
|
|
115
|
+
|
|
94
116
|
# Setup pyqt libraries
|
|
95
117
|
- name: Setup qtpy libraries
|
|
96
118
|
uses: tlambert03/setup-qt-libs@v1
|
|
@@ -108,7 +130,7 @@ jobs:
|
|
|
108
130
|
test_brainmapper_cli:
|
|
109
131
|
needs: [linting, manifest]
|
|
110
132
|
name: Run brainmapper tests to check for breakages
|
|
111
|
-
timeout-minutes:
|
|
133
|
+
timeout-minutes: 120
|
|
112
134
|
runs-on: ubuntu-latest
|
|
113
135
|
env:
|
|
114
136
|
KERAS_BACKEND: torch
|
|
@@ -1,6 +1,6 @@
|
|
|
1
1
|
Metadata-Version: 2.1
|
|
2
2
|
Name: cellfinder
|
|
3
|
-
Version: 1.3.3
|
|
3
|
+
Version: 1.4.0a0
|
|
4
4
|
Summary: Automated 3D cell detection in large microscopy images
|
|
5
5
|
Author-email: "Adam Tyson, Christian Niedworok, Charly Rousseau" <code@adamltyson.com>
|
|
6
6
|
License: BSD-3-Clause
|
|
@@ -33,7 +33,7 @@ Requires-Dist: numpy
|
|
|
33
33
|
Requires-Dist: scikit-image
|
|
34
34
|
Requires-Dist: scikit-learn
|
|
35
35
|
Requires-Dist: keras==3.5.0
|
|
36
|
-
Requires-Dist: torch
|
|
36
|
+
Requires-Dist: torch!=2.4,>=2.1.0
|
|
37
37
|
Requires-Dist: tifffile
|
|
38
38
|
Requires-Dist: tqdm
|
|
39
39
|
Provides-Extra: dev
|
|
@@ -46,6 +46,7 @@ Requires-Dist: pytest-qt; extra == "dev"
|
|
|
46
46
|
Requires-Dist: pytest-timeout; extra == "dev"
|
|
47
47
|
Requires-Dist: pytest; extra == "dev"
|
|
48
48
|
Requires-Dist: tox; extra == "dev"
|
|
49
|
+
Requires-Dist: pooch>=1; extra == "dev"
|
|
49
50
|
Provides-Extra: napari
|
|
50
51
|
Requires-Dist: brainglobe-napari-io; extra == "napari"
|
|
51
52
|
Requires-Dist: magicgui; extra == "napari"
|
|
@@ -0,0 +1,236 @@
|
|
|
1
|
+
"""
|
|
2
|
+
Detection is run in three steps:
|
|
3
|
+
|
|
4
|
+
1. 2D filtering
|
|
5
|
+
2. 3D filtering
|
|
6
|
+
3. Structure detection
|
|
7
|
+
|
|
8
|
+
In steps 1. and 2. filters are applied, and any bright points detected
|
|
9
|
+
post-filter are marked. To avoid using a separate mask array to mark the
|
|
10
|
+
bright points, the input data is clipped to [0, (max_val - 2)]
|
|
11
|
+
(max_val is the maximum value that the image data type can store), and:
|
|
12
|
+
- (max_val - 1) is used to mark bright points during 2D filtering
|
|
13
|
+
- (max_val) is used to mark bright points during 3D filtering
|
|
14
|
+
"""
|
|
15
|
+
|
|
16
|
+
import dataclasses
|
|
17
|
+
from datetime import datetime
|
|
18
|
+
from typing import Callable, List, Optional, Tuple
|
|
19
|
+
|
|
20
|
+
import numpy as np
|
|
21
|
+
import torch
|
|
22
|
+
from brainglobe_utils.cells.cells import Cell
|
|
23
|
+
|
|
24
|
+
from cellfinder.core import logger, types
|
|
25
|
+
from cellfinder.core.detect.filters.plane import TileProcessor
|
|
26
|
+
from cellfinder.core.detect.filters.setup_filters import DetectionSettings
|
|
27
|
+
from cellfinder.core.detect.filters.volume.volume_filter import VolumeFilter
|
|
28
|
+
from cellfinder.core.tools.tools import inference_wrapper
|
|
29
|
+
|
|
30
|
+
|
|
31
|
+
@inference_wrapper
def main(
    signal_array: types.array,
    start_plane: int = 0,
    end_plane: int = -1,
    voxel_sizes: Tuple[float, float, float] = (5, 2, 2),
    soma_diameter: float = 16,
    max_cluster_size: float = 100_000,
    ball_xy_size: float = 6,
    ball_z_size: float = 15,
    ball_overlap_fraction: float = 0.6,
    soma_spread_factor: float = 1.4,
    n_free_cpus: int = 2,
    log_sigma_size: float = 0.2,
    n_sds_above_mean_thresh: float = 10,
    outlier_keep: bool = False,
    artifact_keep: bool = False,
    save_planes: bool = False,
    plane_directory: Optional[str] = None,
    batch_size: Optional[int] = None,
    torch_device: str = "cpu",
    use_scipy: bool = True,
    split_ball_xy_size: int = 3,
    split_ball_z_size: int = 3,
    split_ball_overlap_fraction: float = 0.8,
    split_soma_diameter: int = 7,
    *,
    callback: Optional[Callable[[int], None]] = None,
) -> List[Cell]:
    """
    Perform cell candidate detection on a 3D signal array.

    Parameters
    ----------
    signal_array : numpy.ndarray
        3D array representing the signal data. Must have a numeric dtype.

    start_plane : int
        Index of the starting plane for detection.

    end_plane : int
        Index of the ending plane for detection. Negative values mean
        "up to the last plane"; values past the end are clamped.

    voxel_sizes : Tuple[float, float, float]
        Tuple of voxel sizes in each dimension (z, y, x). Elements are
        converted with ``float`` so string values are accepted.

    soma_diameter : float
        Diameter of the soma in physical units.

    max_cluster_size : float
        Maximum size of a cluster in physical units.

    ball_xy_size : float
        Size of the XY ball used for filtering in physical units.

    ball_z_size : float
        Size of the Z ball used for filtering in physical units.

    ball_overlap_fraction : float
        Fraction of overlap allowed between balls.

    soma_spread_factor : float
        Spread factor for soma size.

    n_free_cpus : int
        Number of free CPU cores available for parallel processing.

    log_sigma_size : float
        Size of the sigma for the log filter.

    n_sds_above_mean_thresh : float
        Number of standard deviations above the mean threshold.

    outlier_keep : bool, optional
        Whether to keep outliers during detection. Defaults to False.

    artifact_keep : bool, optional
        Whether to keep artifacts during detection. Defaults to False.

    save_planes : bool, optional
        Whether to save the planes during detection. Defaults to False.

    plane_directory : str, optional
        Directory path to save the planes. Defaults to None.

    batch_size : int, optional
        The number of planes to process in each batch. Defaults to None,
        in which case 4 is used when running on the CPU and 1 otherwise.
        Values below 1 are raised to 1. For CUDA, the larger the batch
        size the better the performance, until it fills up the GPU
        memory - after which it becomes slower.

    torch_device : str, optional
        The device on which to run the computation. By default, it's "cpu".
        To run on a gpu, specify the PyTorch device name, such as "cuda" to
        run on the first GPU. The name is case-insensitive.

    use_scipy : bool, optional
        When running on the CPU, whether the 2D filtering uses scipy
        filters instead of the PyTorch filters. Defaults to True.

    split_ball_xy_size : int, optional
        XY ball size, in pixels, used when splitting large structures.

    split_ball_z_size : int, optional
        Z ball size, in planes, used when splitting large structures.

    split_ball_overlap_fraction : float, optional
        Ball overlap fraction used when splitting large structures.

    split_soma_diameter : int, optional
        Soma diameter, in pixels, used when splitting large structures.

    callback : Callable[int], optional
        A callback function that is called every time a plane has finished
        being processed. Called with the plane number that has finished.

    Returns
    -------
    List[Cell]
        List of detected cells.

    Raises
    ------
    TypeError
        If ``signal_array`` does not have a numeric dtype.
    ValueError
        If ``signal_array`` is not 3D.
    """
    start_time = datetime.now()
    # Normalize the device name first so every device comparison below
    # (including the default batch size) sees the same canonical name.
    torch_device = torch_device.lower()
    if batch_size is None:
        batch_size = 4 if torch_device == "cpu" else 1

    if not np.issubdtype(signal_array.dtype, np.number):
        raise TypeError(
            "signal_array must be a numpy datatype, but has datatype "
            f"{signal_array.dtype}"
        )

    if signal_array.ndim != 3:
        raise ValueError("Input data must be 3D")

    if end_plane < 0:
        end_plane = len(signal_array)
    end_plane = min(len(signal_array), end_plane)

    batch_size = max(batch_size, 1)
    # brainmapper can pass them in as str
    voxel_sizes = list(map(float, voxel_sizes))

    settings = DetectionSettings(
        plane_shape=signal_array.shape[1:],
        plane_original_np_dtype=signal_array.dtype,
        voxel_sizes=voxel_sizes,
        soma_spread_factor=soma_spread_factor,
        soma_diameter_um=soma_diameter,
        max_cluster_size_um3=max_cluster_size,
        ball_xy_size_um=ball_xy_size,
        ball_z_size_um=ball_z_size,
        start_plane=start_plane,
        end_plane=end_plane,
        n_free_cpus=n_free_cpus,
        ball_overlap_fraction=ball_overlap_fraction,
        log_sigma_size=log_sigma_size,
        n_sds_above_mean_thresh=n_sds_above_mean_thresh,
        outlier_keep=outlier_keep,
        artifact_keep=artifact_keep,
        save_planes=save_planes,
        plane_directory=plane_directory,
        batch_size=batch_size,
        torch_device=torch_device,
    )

    # Copy the settings for splitting before accessing any (possibly
    # cached) properties of the original settings, then override the
    # splitting-specific values. The split sizes are given in pixels, so
    # convert them to physical units using the pixel sizes.
    kwargs = dataclasses.asdict(settings)
    kwargs["ball_z_size_um"] = split_ball_z_size * settings.z_pixel_size
    kwargs["ball_xy_size_um"] = (
        split_ball_xy_size * settings.in_plane_pixel_size
    )
    kwargs["ball_overlap_fraction"] = split_ball_overlap_fraction
    kwargs["soma_diameter_um"] = (
        split_soma_diameter * settings.in_plane_pixel_size
    )
    # always run on cpu because copying to gpu overhead is likely slower than
    # any benefit for detection on smallish volumes
    kwargs["torch_device"] = "cpu"
    # for splitting, we only do 3d filtering. Its input is a zero volume
    # with cell voxels marked with threshold_value. So just use float32
    # for input because the filters will also use float(32). So there will
    # not be need to convert the input a different dtype before passing to
    # the filters.
    kwargs["plane_original_np_dtype"] = np.float32
    splitting_settings = DetectionSettings(**kwargs)

    # Create 3D analysis filter
    mp_3d_filter = VolumeFilter(settings=settings)

    # Create 2D analysis filter
    mp_tile_processor = TileProcessor(
        plane_shape=settings.plane_shape,
        clipping_value=settings.clipping_value,
        threshold_value=settings.threshold_value,
        n_sds_above_mean_thresh=n_sds_above_mean_thresh,
        log_sigma_size=log_sigma_size,
        soma_diameter=settings.soma_diameter,
        torch_device=torch_device,
        dtype=settings.filtering_dtype.__name__,
        use_scipy=use_scipy,
    )

    orig_n_threads = torch.get_num_threads()
    torch.set_num_threads(settings.n_torch_comp_threads)
    try:
        # process the data
        mp_3d_filter.process(
            mp_tile_processor, signal_array, callback=callback
        )
        cells = mp_3d_filter.get_results(splitting_settings)
    finally:
        # The torch thread count is process-global state; restore it even
        # when processing fails so callers are not left with our setting.
        torch.set_num_threads(orig_n_threads)

    time_elapsed = datetime.now() - start_time
    s = f"Detection complete. Found {len(cells)} cells in {time_elapsed}"
    logger.debug(s)
    print(s)
    return cells
|
|
@@ -0,0 +1,347 @@
|
|
|
1
|
+
import numpy as np
|
|
2
|
+
import torch
|
|
3
|
+
import torch.nn.functional as F
|
|
4
|
+
from scipy.ndimage import gaussian_filter, laplace
|
|
5
|
+
from scipy.signal import medfilt2d
|
|
6
|
+
|
|
7
|
+
|
|
8
|
+
@torch.jit.script
def normalize(
    filtered_planes: torch.Tensor,
    flip: bool,
    max_value: float = 1.0,
) -> None:
    """
    Rescale every z-plane of a 3d tensor, in place, into [0, max_value].

    Each plane (first axis) is shifted so its minimum is 0 and then divided
    by its resulting maximum, so planes are scaled independently. Planes
    that are constant end up all zero.

    Parameters
    ----------
    filtered_planes : torch.Tensor
        The tensor to normalize; it is modified in place through a
        ``view``, so it must be viewable as ``(num_z, -1)``.
    flip : bool
        When True, the values are negated before any other processing.
    max_value : float
        The maximum of each plane after scaling. Defaults to 1.0.
    """
    # One row per z-plane; being a view, edits write into filtered_planes.
    flat = filtered_planes.view(filtered_planes.shape[0], -1)

    if flip:
        flat.neg_()

    flat.sub_(torch.amin(flat, dim=1, keepdim=True))

    # The max is taken after the shift, so it is the per-plane value range.
    span = torch.amax(flat, dim=1, keepdim=True)
    # A zero range means the plane is all zero now; dividing by 1 keeps it.
    span.masked_fill_(span == 0, 1.0)
    flat.div_(span)

    if max_value != 1.0:
        # Scaling below the dtype max leaves room for labels in 3d detection.
        flat.mul_(max_value)
|
|
38
|
+
|
|
39
|
+
|
|
40
|
+
@torch.jit.script
def filter_for_peaks(
    planes: torch.Tensor,
    med_kernel: torch.Tensor,
    gauss_kernel: torch.Tensor,
    gauss_kernel_size: int,
    lap_kernel: torch.Tensor,
    device: str,
    clipping_value: float,
) -> torch.Tensor:
    """
    Takes the 3d z-stack and returns a new z-stack where the peaks are
    highlighted.

    It applies a median filter -> gaussian filter -> laplacian filter, and
    finally normalizes each plane into the [0, clipping_value] range.

    Parameters
    ----------
    planes : torch.Tensor
        The ZYX z-stack to filter. It is not modified.
    med_kernel : torch.Tensor
        Patch-extraction kernel of shape (K*K, 1, K, K) used to gather the
        neighbourhood of each pixel for the median filter.
    gauss_kernel : torch.Tensor
        1D gaussian kernel, shaped 11K1 when ``device`` is "cpu" and 111K
        otherwise.
    gauss_kernel_size : int
        The length K of ``gauss_kernel``; determines the padding.
    lap_kernel : torch.Tensor
        2D laplacian kernel of shape (1, 1, K, K).
    device : str
        The torch device name; selects which axis the 1D gaussian kernel is
        applied along directly.
    clipping_value : float
        The max value of each plane after the final normalization.

    Returns
    -------
    torch.Tensor
        The filtered ZYX z-stack.
    """
    filtered_planes = planes.unsqueeze(1)  # ZYX -> ZCYX input, C=channels

    # ------------------ median filter ------------------
    # extracts patches to compute median over for each pixel
    # We go from ZCYX -> ZCYX, C=1 to C=9 with C containing the elements around
    # each Z,X,Y voxel over which we compute the median
    # Zero padding is ok here
    filtered_planes = F.conv2d(filtered_planes, med_kernel, padding="same")
    # we're going back to ZCYX=Z1YX by taking median of patches in C dim
    filtered_planes = filtered_planes.median(dim=1, keepdim=True)[0]

    # ------------------ gaussian filter ------------------
    # normalize the input data to 0-1 range. Otherwise, if the values are
    # large, we'd need a float64 so conv result is accurate
    normalize(filtered_planes, flip=False)

    # we need to do reflection padding around the tensor for parity with scipy
    # gaussian filtering. Scipy does reflection in a manner typically called
    # symmetric: (dcba|abcd|dcba). Torch does it like this: (dcb|abcd|cba). So
    # we manually do symmetric padding below
    pad = gauss_kernel_size // 2
    padding_mode = "reflect"
    # if data is too small for reflect, just use constant border value
    if pad >= filtered_planes.shape[-1] or pad >= filtered_planes.shape[-2]:
        padding_mode = "replicate"
    filtered_planes = F.pad(filtered_planes, (pad,) * 4, padding_mode, 0.0)
    # A size-1 kernel (pad == 0) adds no border, so there is nothing to
    # shift. Skipping is required: the ``-pad:`` slices below would select
    # the whole axis when pad == 0 and the assignment would fail.
    if pad > 0:
        # We reflected torch style, so copy/shift everything by one to be
        # symmetric
        filtered_planes[:, :, :pad, :] = filtered_planes[
            :, :, 1 : pad + 1, :
        ].clone()
        filtered_planes[:, :, -pad:, :] = filtered_planes[
            :, :, -pad - 1 : -1, :
        ].clone()
        filtered_planes[:, :, :, :pad] = filtered_planes[
            :, :, :, 1 : pad + 1
        ].clone()
        filtered_planes[:, :, :, -pad:] = filtered_planes[
            :, :, :, -pad - 1 : -1
        ].clone()

    # We apply the 1D gaussian filter twice, once for Y and once for X. The
    # filter shape passed in is 11K1 or 111K, depending on device. Where
    # K=filter size
    # see https://discuss.pytorch.org/t/performance-issue-for-conv2d-with-1d-
    # filter-along-a-dim/201734/2 for the reason for the moveaxis depending
    # on the device
    if device == "cpu":
        # kernel shape is 11K1. First do Y (second to last axis)
        filtered_planes = F.conv2d(
            filtered_planes, gauss_kernel, padding="valid"
        )
        # To do X, exchange X,Y axis, filter, change back. On CPU, Y (second
        # to last) axis is faster.
        filtered_planes = F.conv2d(
            filtered_planes.moveaxis(-1, -2), gauss_kernel, padding="valid"
        ).moveaxis(-1, -2)
    else:
        # kernel shape is 111K
        # First do Y (second to last axis). Exchange X,Y axis, filter, change
        # back. On CUDA, X (last) axis is faster.
        filtered_planes = F.conv2d(
            filtered_planes.moveaxis(-1, -2), gauss_kernel, padding="valid"
        ).moveaxis(-1, -2)
        # now do X, last axis
        filtered_planes = F.conv2d(
            filtered_planes, gauss_kernel, padding="valid"
        )

    # ------------------ laplacian filter ------------------
    # it's a 2d filter. Need to pad using symmetric for scipy parity. But,
    # torch doesn't have it, and we used a kernel of size 3, so for padding of
    # 1, replicate == symmetric. That's enough for parity with past scipy. If
    # we change kernel size in the future, we may have to do as above
    padding = lap_kernel.shape[-1] // 2
    filtered_planes = F.pad(filtered_planes, (padding,) * 4, "replicate")
    filtered_planes = F.conv2d(filtered_planes, lap_kernel, padding="valid")

    # we don't need the channel axis
    filtered_planes = filtered_planes[:, 0, :, :]

    # scale back to full scale, filtered values are negative so flip
    normalize(filtered_planes, flip=True, max_value=clipping_value)
    return filtered_planes
|
|
139
|
+
|
|
140
|
+
|
|
141
|
+
class PeakEnhancer:
    """
    A class that filters each plane in a z-stack such that peaks are
    visualized.

    It uses a series of 2D filters of median -> gaussian ->
    laplacian. Then normalizes each plane to be between [0, clipping_value].

    Parameters
    ----------
    torch_device: str
        The device on which the data and processing occurs on. Can be e.g.
        "cpu", "cuda" etc. The name is case-insensitive. Any data passed to
        the filter must be on this device. Returned data will also be on
        this device.
    dtype : torch.dtype
        The data-type of the input planes and the type to use internally.
        E.g. `torch.float32`.
    clipping_value : float
        The value such that after normalizing, the max value will be this
        clipping_value.
    laplace_gaussian_sigma : float
        Size of the sigma for the gaussian filter, in pixels.
    use_scipy : bool
        If running on the CPU whether to use the scipy filters or the same
        pytorch filters used on CUDA. Scipy filters can be faster.
    """

    # binary kernel that generates square patches for each pixel so we can find
    # the median around the pixel
    med_kernel: torch.Tensor

    # gaussian 1D kernel with kernel/weight shape 11K1 or 111K, depending
    # on device. Where K=filter size
    gauss_kernel: torch.Tensor

    # 2D laplacian kernel with kernel/weight shape KxK. Where
    # K=filter size
    lap_kernel: torch.Tensor

    # the value such that after normalizing, the max value will be this
    # clipping_value
    clipping_value: float

    # sigma value for gaussian filter
    laplace_gaussian_sigma: float

    # the torch device to run on, lower-cased. E.g. cpu/cuda.
    torch_device: str

    # when running on CPU whether to use pytorch or scipy for filters
    use_scipy: bool

    # The median filter size in x/y direction. **Must** be odd.
    median_filter_size: int = 3

    def __init__(
        self,
        torch_device: str,
        dtype: torch.dtype,
        clipping_value: float,
        laplace_gaussian_sigma: float,
        use_scipy: bool,
    ):
        super().__init__()
        self.torch_device = torch_device.lower()
        self.clipping_value = clipping_value
        self.laplace_gaussian_sigma = laplace_gaussian_sigma
        self.use_scipy = use_scipy

        # All these kernels are odd in size. They are built with the
        # normalized device name: torch device strings are case-sensitive,
        # and the gaussian kernel orientation must match the branch that
        # filter_for_peaks selects from self.torch_device.
        self.med_kernel = self._get_median_kernel(self.torch_device, dtype)
        self.gauss_kernel = self._get_gaussian_kernel(
            self.torch_device, dtype, laplace_gaussian_sigma
        )
        self.lap_kernel = self._get_laplacian_kernel(self.torch_device, dtype)

    @property
    def gaussian_filter_size(self) -> int:
        """
        The gaussian filter 1d size.

        It is odd. The radius is computed the same way as scipy's
        ``gaussian_filter`` with its default ``truncate=4.0``, i.e.
        ``int(4 * sigma + 0.5)``. This keeps the torch kernel the same
        length as the scipy one. (``round`` rounds half to even and would
        give a shorter, truncated kernel at e.g. sigma=0.625.)
        """
        return 2 * int(4 * self.laplace_gaussian_sigma + 0.5) + 1

    def _get_median_kernel(
        self, torch_device: str, dtype: torch.dtype
    ) -> torch.Tensor:
        """
        Gets a median patch generator kernel, already on the correct
        device.

        Based on how kornia does it for median filtering.

        Raises
        ------
        ValueError
            If ``median_filter_size`` is even.
        """
        # must be odd kernel
        kernel_n = self.median_filter_size
        if not (kernel_n % 2):
            raise ValueError("The median filter size must be odd")

        # extract patches to compute median over for each pixel. When passing
        # input we go from ZCYX -> ZCYX, C=1 to C=9 and containing the elements
        # around each Z,X,Y over which we can then compute the median
        window_range = kernel_n * kernel_n  # e.g. 3x3
        kernel = torch.zeros(
            (window_range, window_range), device=torch_device, dtype=dtype
        )
        idx = torch.arange(window_range, device=torch_device)
        # diagonal of e.g. 9x9 is 1
        kernel[idx, idx] = 1.0
        # out channels, in channels, n*y, n*x. The kernel collects all the 3x3
        # elements around a pixel, using a binary mask for each element, as a
        # separate channel. So we go from 1 to 9 channels in the output
        kernel = kernel.view(window_range, 1, kernel_n, kernel_n)

        return kernel

    def _get_gaussian_kernel(
        self,
        torch_device: str,
        dtype: torch.dtype,
        laplace_gaussian_sigma: float,
    ) -> torch.Tensor:
        """Gets the 1D gaussian kernel used to filter the data."""
        # we do 2 1D filters, once on each y, x dim.
        # shape of kernel will be 11K1 with dims Z, C, Y, X. C=1, Z is expanded
        # to number of z during filtering.
        kernel_size = self.gaussian_filter_size

        # to get the values of a 1D gaussian kernel, we pass a single impulse
        # data through the filter, which recovers the filter values. We do this
        # because scipy doesn't make their kernel available in public API and
        # we want parity with scipy filtering
        impulse = np.zeros(kernel_size)
        # the impulse sits at the center index; the size is odd, so the
        # filtered impulse recovers the full symmetric kernel
        impulse[kernel_size // 2] = 1
        kernel = gaussian_filter(
            impulse, laplace_gaussian_sigma, mode="constant", cval=0
        )
        # kernel should be fully symmetric
        assert kernel[0] == kernel[-1]
        gauss_kernel = torch.from_numpy(kernel).type(dtype).to(torch_device)

        # default shape is (y, x) with y axis filtered only - we transpose
        # input to filter on x
        gauss_kernel = gauss_kernel.view(1, 1, -1, 1)

        # see https://discuss.pytorch.org/t/performance-issue-for-conv2d-
        # with-1d-filter-along-a-dim/201734. Conv2d is faster on a specific dim
        # for 1D filters depending on CPU/CUDA. See also filter_for_peaks
        # on CPU, we only do conv2d on the Y (second to last) dim
        if torch_device != "cpu":
            # on CUDA, we only filter on the x dim, flipping input to filter y
            gauss_kernel = gauss_kernel.view(1, 1, 1, -1)

        return gauss_kernel

    def _get_laplacian_kernel(
        self, torch_device: str, dtype: torch.dtype
    ) -> torch.Tensor:
        """Gets a 2d laplacian kernel, based on scipy's laplace."""
        # for parity with scipy, scipy computes the laplacian with default
        # parameters and kernel size 3 using filter coefficients [1, -2, 1].
        # Each filtered pixel is the sum of the filter around the pixel
        # vertically and horizontally. We can do it in 2d at once with
        # coefficients below (faster than 2x1D for such small filter)
        return torch.as_tensor(
            [[0, 1, 0], [1, -4, 1], [0, 1, 0]],
            dtype=dtype,
            device=torch_device,
        ).view(1, 1, 3, 3)

    def enhance_peaks(self, planes: torch.Tensor) -> torch.Tensor:
        """
        Applies the filtering and normalization to the 3d z-stack (not inplace)
        and returns the filtered z-stack.

        On the CPU with ``use_scipy``, the planes are filtered one at a time
        with scipy; otherwise the torch implementation is used.
        """
        if self.torch_device == "cpu" and self.use_scipy:
            filtered_planes = planes.clone()
            for i in range(planes.shape[0]):
                img = planes[i, :, :].numpy()
                img = medfilt2d(img)
                img = gaussian_filter(img, self.laplace_gaussian_sigma)
                img = laplace(img)
                filtered_planes[i, :, :] = torch.from_numpy(img)

            # laplace makes values negative so flip
            normalize(
                filtered_planes,
                flip=True,
                max_value=self.clipping_value,
            )
            return filtered_planes

        filtered_planes = filter_for_peaks(
            planes,
            self.med_kernel,
            self.gauss_kernel,
            self.gaussian_filter_size,
            self.lap_kernel,
            self.torch_device,
            self.clipping_value,
        )
        return filtered_planes
|